From a9c431aad08b5d8d54b7fdb84fc3d479aebaa036 Mon Sep 17 00:00:00 2001 From: Piotr Roslaniec Date: Fri, 25 Aug 2023 13:21:58 +0200 Subject: [PATCH] address pr comments in #74 --- nucypher-core-python/src/lib.rs | 19 +++++++---- nucypher-core-wasm/src/lib.rs | 8 ++--- nucypher-core-wasm/tests/wasm.rs | 6 ++-- nucypher-core/src/access_control.rs | 39 ++++++++++------------ nucypher-core/src/threshold_message_kit.rs | 5 +-- 5 files changed, 40 insertions(+), 37 deletions(-) diff --git a/nucypher-core-python/src/lib.rs b/nucypher-core-python/src/lib.rs index dda04728..c3096392 100644 --- a/nucypher-core-python/src/lib.rs +++ b/nucypher-core-python/src/lib.rs @@ -71,7 +71,6 @@ where builtins.getattr("hash")?.call1(((arg1, arg2),))?.extract() }) } - #[pyclass(module = "nucypher_core")] #[derive(Clone, PartialEq, Eq, PartialOrd, Ord, derive_more::AsRef)] pub struct Address { @@ -759,9 +758,12 @@ impl AuthenticatedData { } } - pub fn aad(&self, py: Python) -> PyObject { - let result = self.backend.aad(); - PyBytes::new(py, result.as_ref()).into() + pub fn aad(&self, py: Python) -> PyResult { + let result = self + .backend + .aad() + .map_err(|err| PyValueError::new_err(format!("{}", err)))?; + Ok(PyBytes::new(py, &result).into()) } #[getter] @@ -828,9 +830,12 @@ impl AccessControlPolicy { } } - pub fn aad(&self, py: Python) -> PyObject { - let result = self.backend.auth_data.aad(); - PyBytes::new(py, result.as_ref()).into() + pub fn aad(&self, py: Python) -> PyResult { + let result = self + .backend + .aad() + .map_err(|err| PyValueError::new_err(format!("{}", err)))?; + Ok(PyBytes::new(py, &result).into()) } #[getter] diff --git a/nucypher-core-wasm/src/lib.rs b/nucypher-core-wasm/src/lib.rs index 09c8d8d8..0869a4d2 100644 --- a/nucypher-core-wasm/src/lib.rs +++ b/nucypher-core-wasm/src/lib.rs @@ -680,8 +680,8 @@ impl AuthenticatedData { ))) } - pub fn aad(&self) -> Box<[u8]> { - self.0.aad() + pub fn aad(&self) -> Result, Error> { + self.0.aad().map_err(map_js_err) } #[wasm_bindgen(getter, js_name = publicKey)] @@ -742,8 +742,8 @@ impl AccessControlPolicy { ))) } - pub fn aad(&self) -> Box<[u8]> { - self.0.aad() + pub fn aad(&self) -> Result, Error> { + self.0.aad().map_err(map_js_err) } #[wasm_bindgen(getter, js_name = publicKey)] diff --git a/nucypher-core-wasm/tests/wasm.rs b/nucypher-core-wasm/tests/wasm.rs index 459acbfb..a149b883 100644 --- a/nucypher-core-wasm/tests/wasm.rs +++ b/nucypher-core-wasm/tests/wasm.rs @@ -708,7 +708,7 @@ fn threshold_decryption_request() { let acp = AccessControlPolicy::new(&auth_data, authorization).unwrap(); let message = "my-message".as_bytes(); - let ciphertext = ferveo_encrypt(message, &acp.aad(), &dkg_pk).unwrap(); + let ciphertext = ferveo_encrypt(message, &acp.aad().unwrap(), &dkg_pk).unwrap(); let ciphertext_header = ciphertext.header().unwrap(); let request = ThresholdDecryptionRequest::new( @@ -817,7 +817,7 @@ fn authenticated_data() { let mut expected_aad = dkg_pk.to_bytes().unwrap().to_vec(); expected_aad.extend(conditions.as_bytes()); - assert_eq!(auth_data.aad(), expected_aad.into_boxed_slice()); + assert_eq!(auth_data.aad().unwrap(), expected_aad.into_boxed_slice()); // mimic serialization/deserialization over the wire let serialized_auth_data = auth_data.to_bytes(); @@ -891,7 +891,7 @@ fn threshold_message_kit() { let acp = AccessControlPolicy::new(&auth_data, authorization).unwrap(); let data = "The Tyranny of Merit".as_bytes(); - let ciphertext = ferveo_encrypt(data, &acp.aad(), &dkg_pk).unwrap(); + let ciphertext = ferveo_encrypt(data, &acp.aad().unwrap(), &dkg_pk).unwrap(); let tmk = ThresholdMessageKit::new(&ciphertext, &acp); diff --git a/nucypher-core/src/access_control.rs b/nucypher-core/src/access_control.rs index 989ac521..28e6af25 100644 --- a/nucypher-core/src/access_control.rs +++ b/nucypher-core/src/access_control.rs @@ -1,6 +1,5 @@ use alloc::boxed::Box; use alloc::string::String; -use alloc::vec::Vec; use ferveo::api::{encrypt, Ciphertext, DkgPublicKey, SecretBox}; use ferveo::Error; @@ -13,7 +12,7 @@ use crate::versioning::{ }; /// Authenticated data for encrypted data. -#[derive(Debug, Serialize, Deserialize, Clone)] +#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] pub struct AuthenticatedData { /// The public key for the encrypted data pub public_key: DkgPublicKey, @@ -22,6 +21,8 @@ pub struct AuthenticatedData { pub conditions: Option, } +impl Eq for AuthenticatedData {} + impl AuthenticatedData { /// Creates a new access control policy. pub fn new(public_key: &DkgPublicKey, conditions: Option<&Conditions>) -> Self { @@ -32,24 +33,20 @@ impl AuthenticatedData { } /// Return the aad. - pub fn aad(&self) -> Box<[u8]> { - let public_key_bytes = self.public_key.to_bytes().unwrap(); - let condition_bytes = self.conditions.as_ref().unwrap().as_ref().as_bytes(); - let mut result = Vec::with_capacity(public_key_bytes.len() + condition_bytes.len()); - result.extend(public_key_bytes); - result.extend(condition_bytes); - result.into_boxed_slice() + pub fn aad(&self) -> Result, Error> { + Ok([ + self.public_key.to_bytes()?.to_vec(), + self.conditions + .as_ref() + .map(|c| c.as_ref().as_bytes()) + .unwrap_or_default() + .to_vec(), + ] + .concat() + .into_boxed_slice()) } } -impl PartialEq for AuthenticatedData { - fn eq(&self, other: &Self) -> bool { - self.public_key == other.public_key && self.conditions == other.conditions - } -} - -impl Eq for AuthenticatedData {} - impl<'a> ProtocolObjectInner<'a> for AuthenticatedData { fn version() -> (u16, u16) { (1, 0) @@ -83,7 +80,7 @@ pub fn encrypt_for_dkg( let auth_data = AuthenticatedData::new(public_key, conditions); let ciphertext = encrypt( SecretBox::new(data.to_vec()), - auth_data.aad().as_ref(), + auth_data.aad()?.as_ref(), public_key, )?; Ok((ciphertext, auth_data)) @@ -110,7 +107,7 @@ impl AccessControlPolicy { } /// Return the aad. - pub fn aad(&self) -> Box<[u8]> { + pub fn aad(&self) -> Result, Error> { self.auth_data.aad() } @@ -167,7 +164,7 @@ mod tests { // check aad for auth data; expected to be dkg public key + conditions let mut expected_aad = dkg_pk.to_bytes().unwrap().to_vec(); expected_aad.extend(conditions.as_ref().as_bytes()); - let auth_data_aad = auth_data.aad(); + let auth_data_aad = auth_data.aad().unwrap(); assert_eq!(expected_aad.into_boxed_slice(), auth_data_aad); assert_eq!(auth_data.public_key, dkg_pk); @@ -193,7 +190,7 @@ mod tests { let acp = AccessControlPolicy::new(&auth_data, authorization); // check that aad for auth_data and acp are the same - assert_eq!(auth_data.aad(), acp.aad()); + assert_eq!(auth_data.aad().unwrap(), acp.aad().unwrap()); // mimic serialization/deserialization over the wire let serialized_acp = acp.to_bytes(); diff --git a/nucypher-core/src/threshold_message_kit.rs b/nucypher-core/src/threshold_message_kit.rs index 3ed6b53b..3f4adfec 100644 --- a/nucypher-core/src/threshold_message_kit.rs +++ b/nucypher-core/src/threshold_message_kit.rs @@ -42,7 +42,7 @@ impl ThresholdMessageKit { ) -> Result, Error> { ferveo::api::decrypt_with_shared_secret( &self.ciphertext, - self.acp.aad().as_ref(), + self.acp.aad()?.as_ref(), shared_secret, ) } @@ -92,7 +92,8 @@ mod tests { authorization, ); - let ciphertext = ferveo_encrypt(SecretBox::new(data), &acp.aad(), &dkg_pk).unwrap(); + let ciphertext = + ferveo_encrypt(SecretBox::new(data), &acp.aad().unwrap(), &dkg_pk).unwrap(); let tmk = ThresholdMessageKit::new(&ciphertext, &acp); // mimic serialization/deserialization over the wire