From f9ac1d70b0fc7df286438fa817537c31cb9e7682 Mon Sep 17 00:00:00 2001 From: Piotr Roslaniec Date: Mon, 3 Jul 2023 11:41:40 +0200 Subject: [PATCH] feat! use static arrays in ferveo public key serialization --- Cargo.lock | 1 + ferveo-common/Cargo.toml | 1 + ferveo-common/src/keypair.rs | 48 ++++++------ ferveo-common/src/lib.rs | 31 ++++++++ ferveo-python/test/test_serialization.py | 13 +++- ferveo/src/api.rs | 9 ++- ferveo/src/bindings_python.rs | 95 ++++++++++++++---------- ferveo/src/bindings_wasm.rs | 86 +++++++++++++-------- 8 files changed, 188 insertions(+), 96 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 3052b20e..224fe203 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -749,6 +749,7 @@ dependencies = [ "ark-serialize", "ark-std", "bincode", + "generic-array", "rand 0.8.5", "rand_core 0.6.4", "serde", diff --git a/ferveo-common/Cargo.toml b/ferveo-common/Cargo.toml index c5b5a58e..d1c788c7 100644 --- a/ferveo-common/Cargo.toml +++ b/ferveo-common/Cargo.toml @@ -11,6 +11,7 @@ ark-ec = "0.4" ark-serialize = { version = "0.4", features = ["derive"] } ark-std = "0.4" bincode = "1.3.3" +generic-array = "0.14.7" rand = "0.8" rand_core = "0.6" serde = { version = "1.0", features = ["derive"] } diff --git a/ferveo-common/src/keypair.rs b/ferveo-common/src/keypair.rs index 70716bfd..485241b3 100644 --- a/ferveo-common/src/keypair.rs +++ b/ferveo-common/src/keypair.rs @@ -6,28 +6,26 @@ use ark_std::{ rand::{prelude::StdRng, RngCore, SeedableRng}, UniformRand, }; -use rand_core::Error; +use generic_array::{typenum::U96, GenericArray}; use serde::*; use serde_with::serde_as; -use crate::serialization; +use crate::{serialization, Error, Result}; // Normally, we would use a custom trait for this, but we can't because // the arkworks will not let us create a blanket implementation for G1Affine // and Fr types. So instead, we're using this shared utility function: -pub fn to_bytes( - item: &T, -) -> Result, ark_serialize::SerializationError> { +pub fn to_bytes(item: &T) -> Result> { let mut writer = Vec::new(); - item.serialize_compressed(&mut writer)?; + item.serialize_compressed(&mut writer) + .map_err(Error::SerializationError)?; Ok(writer) } -pub fn from_bytes( - bytes: &[u8], -) -> Result { +pub fn from_bytes(bytes: &[u8]) -> Result { let mut reader = io::Cursor::new(bytes); - let item = T::deserialize_compressed(&mut reader)?; + let item = T::deserialize_compressed(&mut reader) + .map_err(Error::SerializationError)?; Ok(item) } @@ -39,17 +37,25 @@ pub struct PublicKey { } impl PublicKey { - pub fn to_bytes( - &self, - ) -> Result, ark_serialize::SerializationError> { - to_bytes(&self.encryption_key) + pub fn to_bytes(&self) -> Result> { + let as_bytes = to_bytes(&self.encryption_key)?; + Ok(GenericArray::::from_slice(&as_bytes).to_owned()) } - pub fn from_bytes( - bytes: &[u8], - ) -> Result { - let encryption_key = from_bytes(bytes)?; - Ok(PublicKey:: { encryption_key }) + pub fn from_bytes(bytes: &[u8]) -> Result> { + let bytes = + GenericArray::::from_exact_iter(bytes.iter().cloned()) + .ok_or_else(|| { + Error::InvalidByteLength( + Self::serialized_size(), + bytes.len(), + ) + })?; + from_bytes(&bytes).map(|encryption_key| PublicKey { encryption_key }) + } + + pub fn serialized_size() -> usize { + 96 } } @@ -129,9 +135,9 @@ impl Keypair { 32 } - pub fn from_secure_randomness(bytes: &[u8]) -> Result { + pub fn from_secure_randomness(bytes: &[u8]) -> Result { if bytes.len() != Self::secure_randomness_size() { - return Err(Error::new("Invalid seed length")); + return Err(Error::InvalidSeedLength(bytes.len())); } let mut seed = [0; 32]; seed.copy_from_slice(bytes); diff --git a/ferveo-common/src/lib.rs b/ferveo-common/src/lib.rs index f8420468..c041b6da 100644 --- a/ferveo-common/src/lib.rs +++ b/ferveo-common/src/lib.rs @@ -1,5 +1,36 @@ pub mod keypair; pub mod serialization; +use std::{fmt, fmt::Formatter}; + pub use keypair::*; pub use serialization::*; + +#[derive(Debug)] +pub enum Error { + InvalidByteLength(usize, usize), + SerializationError(ark_serialize::SerializationError), + InvalidSeedLength(usize), +} + +impl fmt::Display for Error { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + match self { + Error::InvalidByteLength(expected, actual) => { + write!( + f, + "Invalid byte length: expected {}, actual {}", + expected, actual + ) + } + Error::SerializationError(e) => { + write!(f, "Serialization error: {}", e) + } + Error::InvalidSeedLength(len) => { + write!(f, "Invalid seed length: {}", len) + } + } + } +} + +type Result = std::result::Result; diff --git a/ferveo-python/test/test_serialization.py b/ferveo-python/test/test_serialization.py index 00f800b0..e5de35f0 100644 --- a/ferveo-python/test/test_serialization.py +++ b/ferveo-python/test/test_serialization.py @@ -2,7 +2,8 @@ Keypair, Validator, Dkg, - DkgPublicKey + DkgPublicKey, + FerveoPublicKey, ) @@ -38,6 +39,10 @@ def make_shared_secret(): pass +def make_pk(): + return Keypair.random().public_key() + + # def test_shared_secret_serialization(): # shared_secret = create_shared_secret_instance() # serialized = bytes(shared_secret) @@ -57,3 +62,9 @@ def test_dkg_public_key_serialization(): dkg_pk = make_dkg_public_key() serialized = bytes(dkg_pk) assert len(serialized) == DkgPublicKey.serialized_size() + + +def test_dkg_public_key_serialization(): + pk = make_pk() + serialized = bytes(pk) + assert len(serialized) == FerveoPublicKey.serialized_size() diff --git a/ferveo/src/api.rs b/ferveo/src/api.rs index 04995dfe..99c5af02 100644 --- a/ferveo/src/api.rs +++ b/ferveo/src/api.rs @@ -84,7 +84,12 @@ impl DkgPublicKey { pub fn from_bytes(bytes: &[u8]) -> Result { let bytes = GenericArray::::from_exact_iter(bytes.iter().cloned()) - .ok_or(Error::InvalidByteLength(48, bytes.len()))?; + .ok_or_else(|| { + Error::InvalidByteLength( + Self::serialized_size(), + bytes.len(), + ) + })?; from_bytes(&bytes).map(DkgPublicKey) } @@ -198,7 +203,7 @@ impl AggregatedTranscript { shares_num: u32, messages: &[ValidatorMessage], ) -> Result { - let pvss_params = crate::pvss::PubliclyVerifiableParams::::default(); + let pvss_params = PubliclyVerifiableParams::::default(); let domain = Radix2EvaluationDomain::::new(shares_num as usize) .expect("Unable to construct an evaluation domain"); diff --git a/ferveo/src/bindings_python.rs b/ferveo/src/bindings_python.rs index c324e76f..05756164 100644 --- a/ferveo/src/bindings_python.rs +++ b/ferveo/src/bindings_python.rs @@ -172,7 +172,18 @@ where } } -macro_rules! generate_common_methods { +// TODO: Consider implementing macros to generate following methods + +// fn __richcmp__(&self, other: &Self, op: CompareOp) -> PyResult { +// richcmp(self, other, op) +// } + +// fn __hash__(&self) -> PyResult { +// let bytes = self.0.to_bytes()?; +// hash(stringify!($struct_name), &bytes) +// } + +macro_rules! generate_bytes_serialization { ($struct_name:ident) => { #[pymethods] impl $struct_name { @@ -184,17 +195,35 @@ macro_rules! generate_common_methods { fn __bytes__(&self) -> PyResult { to_py_bytes(&self.0) } + } + }; +} - // TODO: Consider implementing this for all structs - Requires PartialOrd and other traits +macro_rules! generate_boxed_bytes_serialization { + ($struct_name:ident, $inner_struct_name:ident) => { + #[pymethods] + impl $struct_name { + #[staticmethod] + pub fn from_bytes(bytes: &[u8]) -> PyResult { + Ok($struct_name( + $inner_struct_name::from_bytes(bytes).map_err(|err| { + FerveoPythonError::Other(err.to_string()) + })?, + )) + } - // fn __richcmp__(&self, other: &Self, op: CompareOp) -> PyResult { - // richcmp(self, other, op) - // } + fn __bytes__(&self) -> PyResult { + let bytes = self + .0 + .to_bytes() + .map_err(|err| FerveoPythonError::Other(err.to_string()))?; + as_py_bytes(&bytes) + } - // fn __hash__(&self) -> PyResult { - // let bytes = self.0.to_bytes()?; - // hash(stringify!($struct_name), &bytes) - // } + #[staticmethod] + pub fn serialized_size() -> usize { + $inner_struct_name::serialized_size() + } } }; } @@ -253,13 +282,13 @@ pub fn decrypt_with_shared_secret( #[derive(derive_more::AsRef)] pub struct SharedSecret(api::SharedSecret); -generate_common_methods!(SharedSecret); +generate_bytes_serialization!(SharedSecret); #[pyclass(module = "ferveo")] #[derive(derive_more::From, derive_more::AsRef)] pub struct Keypair(api::Keypair); -generate_common_methods!(Keypair); +generate_bytes_serialization!(Keypair); #[pymethods] impl Keypair { @@ -285,13 +314,15 @@ impl Keypair { } } +type InnerPublicKey = api::PublicKey; + #[pyclass(module = "ferveo")] #[derive( Clone, PartialEq, PartialOrd, Eq, derive_more::From, derive_more::AsRef, )] -pub struct FerveoPublicKey(api::PublicKey); +pub struct FerveoPublicKey(InnerPublicKey); -generate_common_methods!(FerveoPublicKey); +generate_boxed_bytes_serialization!(FerveoPublicKey, InnerPublicKey); #[pymethods] impl FerveoPublicKey { @@ -303,7 +334,7 @@ impl FerveoPublicKey { let bytes = self .0 .to_bytes() - .map_err(|err| FerveoPythonError::FerveoError(err.into()))?; + .map_err(|err| FerveoPythonError::Other(err.to_string()))?; hash("FerveoPublicKey", &bytes) } } @@ -339,33 +370,15 @@ impl Validator { #[derive(Clone, derive_more::From, derive_more::AsRef)] pub struct Transcript(api::Transcript); -generate_common_methods!(Transcript); +generate_bytes_serialization!(Transcript); + +type InnerDkgPublicKey = api::DkgPublicKey; #[pyclass(module = "ferveo")] #[derive(Clone, derive_more::From, derive_more::AsRef)] -pub struct DkgPublicKey(api::DkgPublicKey); +pub struct DkgPublicKey(InnerDkgPublicKey); -#[pymethods] -impl DkgPublicKey { - #[staticmethod] - pub fn from_bytes(bytes: &[u8]) -> PyResult { - Ok(Self( - api::DkgPublicKey::from_bytes(bytes) - .map_err(FerveoPythonError::FerveoError)?, - )) - } - - fn __bytes__(&self) -> PyResult { - let bytes = - self.0.to_bytes().map_err(FerveoPythonError::FerveoError)?; - as_py_bytes(&bytes) - } - - #[staticmethod] - pub fn serialized_size() -> usize { - api::DkgPublicKey::serialized_size() - } -} +generate_boxed_bytes_serialization!(DkgPublicKey, InnerDkgPublicKey); #[pyclass(module = "ferveo")] #[derive(derive_more::From, derive_more::AsRef, Clone)] @@ -462,25 +475,25 @@ impl Dkg { )] pub struct Ciphertext(api::Ciphertext); -generate_common_methods!(Ciphertext); +generate_bytes_serialization!(Ciphertext); #[pyclass(module = "ferveo")] #[derive(Clone, derive_more::AsRef, derive_more::From)] pub struct DecryptionShareSimple(api::DecryptionShareSimple); -generate_common_methods!(DecryptionShareSimple); +generate_bytes_serialization!(DecryptionShareSimple); #[pyclass(module = "ferveo")] #[derive(Clone, derive_more::AsRef, derive_more::From)] pub struct DecryptionSharePrecomputed(api::DecryptionSharePrecomputed); -generate_common_methods!(DecryptionSharePrecomputed); +generate_bytes_serialization!(DecryptionSharePrecomputed); #[pyclass(module = "ferveo")] #[derive(derive_more::From, derive_more::AsRef)] pub struct AggregatedTranscript(api::AggregatedTranscript); -generate_common_methods!(AggregatedTranscript); +generate_bytes_serialization!(AggregatedTranscript); #[pymethods] impl AggregatedTranscript { diff --git a/ferveo/src/bindings_wasm.rs b/ferveo/src/bindings_wasm.rs index e4b976a3..8e071564 100644 --- a/ferveo/src/bindings_wasm.rs +++ b/ferveo/src/bindings_wasm.rs @@ -99,7 +99,19 @@ fn unwrap_messages_js( Ok(messages) } -macro_rules! generate_common_methods { +macro_rules! generate_equals { + ($struct_name:ident) => { + #[wasm_bindgen] + impl $struct_name { + #[wasm_bindgen] + pub fn equals(&self, other: &$struct_name) -> bool { + self.0 == other.0 + } + } + }; +} + +macro_rules! generate_bytes_serialization { ($struct_name:ident) => { #[wasm_bindgen] impl $struct_name { @@ -112,15 +124,43 @@ macro_rules! generate_common_methods { pub fn to_bytes(&self) -> JsResult> { to_js_bytes(&self.0) } + } + }; +} - #[wasm_bindgen] - pub fn equals(&self, other: &$struct_name) -> bool { - self.0 == other.0 +macro_rules! generate_boxed_bytes_serialization { + ($struct_name:ident, $inner_struct_name:ident) => { + #[wasm_bindgen] + impl $struct_name { + #[wasm_bindgen(js_name = "fromBytes")] + pub fn from_bytes(bytes: &[u8]) -> JsResult<$struct_name> { + $inner_struct_name::from_bytes(bytes) + .map_err(map_js_err) + .map(Self) + } + + #[wasm_bindgen(js_name = "toBytes")] + pub fn to_bytes(&self) -> JsResult> { + let bytes = self.0.to_bytes().map_err(map_js_err)?; + let bytes: Box<[u8]> = bytes.as_slice().into(); + Ok(bytes) + } + + #[wasm_bindgen(js_name = "serializedSize")] + pub fn serialized_size() -> usize { + $inner_struct_name::serialized_size() } } }; } +macro_rules! generate_common_methods { + ($struct_name:ident) => { + generate_equals!($struct_name); + generate_bytes_serialization!($struct_name); + }; +} + #[derive(TryFromJsValue)] #[wasm_bindgen] #[derive(Clone, Debug, derive_more::AsRef, derive_more::From)] @@ -135,13 +175,16 @@ pub struct DecryptionSharePrecomputed(tpke::api::DecryptionSharePrecomputed); generate_common_methods!(DecryptionSharePrecomputed); +type InnerPublicKey = api::PublicKey; + #[wasm_bindgen] #[derive( Clone, Debug, derive_more::AsRef, derive_more::From, derive_more::Into, )] -pub struct FerveoPublicKey(api::PublicKey); +pub struct FerveoPublicKey(InnerPublicKey); -generate_common_methods!(FerveoPublicKey); +generate_equals!(FerveoPublicKey); +generate_boxed_bytes_serialization!(FerveoPublicKey, InnerPublicKey); #[wasm_bindgen] #[derive( @@ -212,39 +255,20 @@ pub fn decrypt_with_shared_secret( .map_err(map_js_err) } -#[wasm_bindgen] -pub struct DkgPublicKey(api::DkgPublicKey); +type InnerDkgPublicKey = api::DkgPublicKey; #[wasm_bindgen] -impl DkgPublicKey { - #[wasm_bindgen(js_name = "fromBytes")] - pub fn from_bytes(bytes: &[u8]) -> JsResult { - api::DkgPublicKey::from_bytes(bytes) - .map_err(map_js_err) - .map(Self) - } +pub struct DkgPublicKey(InnerDkgPublicKey); - #[wasm_bindgen(js_name = "toBytes")] - pub fn to_bytes(&self) -> JsResult> { - let bytes = self.0.to_bytes().map_err(map_js_err)?; - let bytes: Box<[u8]> = bytes.as_slice().into(); - Ok(bytes) - } +generate_equals!(DkgPublicKey); +generate_boxed_bytes_serialization!(DkgPublicKey, InnerDkgPublicKey); +#[wasm_bindgen] +impl DkgPublicKey { #[wasm_bindgen] pub fn random() -> DkgPublicKey { Self(api::DkgPublicKey::random()) } - - #[wasm_bindgen(js_name = "serializedSize")] - pub fn serialized_size() -> usize { - api::DkgPublicKey::serialized_size() - } - - #[wasm_bindgen] - pub fn equals(&self, other: &DkgPublicKey) -> bool { - self.0 == other.0 - } } #[wasm_bindgen]