From a9c431aad08b5d8d54b7fdb84fc3d479aebaa036 Mon Sep 17 00:00:00 2001
From: Piotr Roslaniec
Date: Fri, 25 Aug 2023 13:21:58 +0200
Subject: [PATCH] address pr comments in #74
---
nucypher-core-python/src/lib.rs | 19 +++++++----
nucypher-core-wasm/src/lib.rs | 8 ++---
nucypher-core-wasm/tests/wasm.rs | 6 ++--
nucypher-core/src/access_control.rs | 39 ++++++++++------------
nucypher-core/src/threshold_message_kit.rs | 5 +--
5 files changed, 40 insertions(+), 37 deletions(-)
diff --git a/nucypher-core-python/src/lib.rs b/nucypher-core-python/src/lib.rs
index dda04728..c3096392 100644
--- a/nucypher-core-python/src/lib.rs
+++ b/nucypher-core-python/src/lib.rs
@@ -71,7 +71,6 @@ where
builtins.getattr("hash")?.call1(((arg1, arg2),))?.extract()
})
}
-
#[pyclass(module = "nucypher_core")]
#[derive(Clone, PartialEq, Eq, PartialOrd, Ord, derive_more::AsRef)]
pub struct Address {
@@ -759,9 +758,12 @@ impl AuthenticatedData {
}
}
- pub fn aad(&self, py: Python) -> PyObject {
- let result = self.backend.aad();
- PyBytes::new(py, result.as_ref()).into()
+ pub fn aad(&self, py: Python) -> PyResult {
+ let result = self
+ .backend
+ .aad()
+ .map_err(|err| PyValueError::new_err(format!("{}", err)))?;
+ Ok(PyBytes::new(py, &result).into())
}
#[getter]
@@ -828,9 +830,12 @@ impl AccessControlPolicy {
}
}
- pub fn aad(&self, py: Python) -> PyObject {
- let result = self.backend.auth_data.aad();
- PyBytes::new(py, result.as_ref()).into()
+ pub fn aad(&self, py: Python) -> PyResult {
+ let result = self
+ .backend
+ .aad()
+ .map_err(|err| PyValueError::new_err(format!("{}", err)))?;
+ Ok(PyBytes::new(py, &result).into())
}
#[getter]
diff --git a/nucypher-core-wasm/src/lib.rs b/nucypher-core-wasm/src/lib.rs
index 09c8d8d8..0869a4d2 100644
--- a/nucypher-core-wasm/src/lib.rs
+++ b/nucypher-core-wasm/src/lib.rs
@@ -680,8 +680,8 @@ impl AuthenticatedData {
)))
}
- pub fn aad(&self) -> Box<[u8]> {
- self.0.aad()
+ pub fn aad(&self) -> Result, Error> {
+ self.0.aad().map_err(map_js_err)
}
#[wasm_bindgen(getter, js_name = publicKey)]
@@ -742,8 +742,8 @@ impl AccessControlPolicy {
)))
}
- pub fn aad(&self) -> Box<[u8]> {
- self.0.aad()
+ pub fn aad(&self) -> Result, Error> {
+ self.0.aad().map_err(map_js_err)
}
#[wasm_bindgen(getter, js_name = publicKey)]
diff --git a/nucypher-core-wasm/tests/wasm.rs b/nucypher-core-wasm/tests/wasm.rs
index 459acbfb..a149b883 100644
--- a/nucypher-core-wasm/tests/wasm.rs
+++ b/nucypher-core-wasm/tests/wasm.rs
@@ -708,7 +708,7 @@ fn threshold_decryption_request() {
let acp = AccessControlPolicy::new(&auth_data, authorization).unwrap();
let message = "my-message".as_bytes();
- let ciphertext = ferveo_encrypt(message, &acp.aad(), &dkg_pk).unwrap();
+ let ciphertext = ferveo_encrypt(message, &acp.aad().unwrap(), &dkg_pk).unwrap();
let ciphertext_header = ciphertext.header().unwrap();
let request = ThresholdDecryptionRequest::new(
@@ -817,7 +817,7 @@ fn authenticated_data() {
let mut expected_aad = dkg_pk.to_bytes().unwrap().to_vec();
expected_aad.extend(conditions.as_bytes());
- assert_eq!(auth_data.aad(), expected_aad.into_boxed_slice());
+ assert_eq!(auth_data.aad().unwrap(), expected_aad.into_boxed_slice());
// mimic serialization/deserialization over the wire
let serialized_auth_data = auth_data.to_bytes();
@@ -891,7 +891,7 @@ fn threshold_message_kit() {
let acp = AccessControlPolicy::new(&auth_data, authorization).unwrap();
let data = "The Tyranny of Merit".as_bytes();
- let ciphertext = ferveo_encrypt(data, &acp.aad(), &dkg_pk).unwrap();
+ let ciphertext = ferveo_encrypt(data, &acp.aad().unwrap(), &dkg_pk).unwrap();
let tmk = ThresholdMessageKit::new(&ciphertext, &acp);
diff --git a/nucypher-core/src/access_control.rs b/nucypher-core/src/access_control.rs
index 989ac521..28e6af25 100644
--- a/nucypher-core/src/access_control.rs
+++ b/nucypher-core/src/access_control.rs
@@ -1,6 +1,5 @@
use alloc::boxed::Box;
use alloc::string::String;
-use alloc::vec::Vec;
use ferveo::api::{encrypt, Ciphertext, DkgPublicKey, SecretBox};
use ferveo::Error;
@@ -13,7 +12,7 @@ use crate::versioning::{
};
/// Authenticated data for encrypted data.
-#[derive(Debug, Serialize, Deserialize, Clone)]
+#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct AuthenticatedData {
/// The public key for the encrypted data
pub public_key: DkgPublicKey,
@@ -22,6 +21,8 @@ pub struct AuthenticatedData {
pub conditions: Option,
}
+impl Eq for AuthenticatedData {}
+
impl AuthenticatedData {
/// Creates a new access control policy.
pub fn new(public_key: &DkgPublicKey, conditions: Option<&Conditions>) -> Self {
@@ -32,24 +33,20 @@ impl AuthenticatedData {
}
/// Return the aad.
- pub fn aad(&self) -> Box<[u8]> {
- let public_key_bytes = self.public_key.to_bytes().unwrap();
- let condition_bytes = self.conditions.as_ref().unwrap().as_ref().as_bytes();
- let mut result = Vec::with_capacity(public_key_bytes.len() + condition_bytes.len());
- result.extend(public_key_bytes);
- result.extend(condition_bytes);
- result.into_boxed_slice()
+ pub fn aad(&self) -> Result, Error> {
+ Ok([
+ self.public_key.to_bytes()?.to_vec(),
+ self.conditions
+ .as_ref()
+ .map(|c| c.as_ref().as_bytes())
+ .unwrap_or_default()
+ .to_vec(),
+ ]
+ .concat()
+ .into_boxed_slice())
}
}
-impl PartialEq for AuthenticatedData {
- fn eq(&self, other: &Self) -> bool {
- self.public_key == other.public_key && self.conditions == other.conditions
- }
-}
-
-impl Eq for AuthenticatedData {}
-
impl<'a> ProtocolObjectInner<'a> for AuthenticatedData {
fn version() -> (u16, u16) {
(1, 0)
@@ -83,7 +80,7 @@ pub fn encrypt_for_dkg(
let auth_data = AuthenticatedData::new(public_key, conditions);
let ciphertext = encrypt(
SecretBox::new(data.to_vec()),
- auth_data.aad().as_ref(),
+ auth_data.aad()?.as_ref(),
public_key,
)?;
Ok((ciphertext, auth_data))
@@ -110,7 +107,7 @@ impl AccessControlPolicy {
}
/// Return the aad.
- pub fn aad(&self) -> Box<[u8]> {
+ pub fn aad(&self) -> Result, Error> {
self.auth_data.aad()
}
@@ -167,7 +164,7 @@ mod tests {
// check aad for auth data; expected to be dkg public key + conditions
let mut expected_aad = dkg_pk.to_bytes().unwrap().to_vec();
expected_aad.extend(conditions.as_ref().as_bytes());
- let auth_data_aad = auth_data.aad();
+ let auth_data_aad = auth_data.aad().unwrap();
assert_eq!(expected_aad.into_boxed_slice(), auth_data_aad);
assert_eq!(auth_data.public_key, dkg_pk);
@@ -193,7 +190,7 @@ mod tests {
let acp = AccessControlPolicy::new(&auth_data, authorization);
// check that aad for auth_data and acp are the same
- assert_eq!(auth_data.aad(), acp.aad());
+ assert_eq!(auth_data.aad().unwrap(), acp.aad().unwrap());
// mimic serialization/deserialization over the wire
let serialized_acp = acp.to_bytes();
diff --git a/nucypher-core/src/threshold_message_kit.rs b/nucypher-core/src/threshold_message_kit.rs
index 3ed6b53b..3f4adfec 100644
--- a/nucypher-core/src/threshold_message_kit.rs
+++ b/nucypher-core/src/threshold_message_kit.rs
@@ -42,7 +42,7 @@ impl ThresholdMessageKit {
) -> Result, Error> {
ferveo::api::decrypt_with_shared_secret(
&self.ciphertext,
- self.acp.aad().as_ref(),
+ self.acp.aad()?.as_ref(),
shared_secret,
)
}
@@ -92,7 +92,8 @@ mod tests {
authorization,
);
- let ciphertext = ferveo_encrypt(SecretBox::new(data), &acp.aad(), &dkg_pk).unwrap();
+ let ciphertext =
+ ferveo_encrypt(SecretBox::new(data), &acp.aad().unwrap(), &dkg_pk).unwrap();
let tmk = ThresholdMessageKit::new(&ciphertext, &acp);
// mimic serialization/deserialization over the wire