From cea467e0bd48a096f70dd1c7ca24a7e4bd88b3d4 Mon Sep 17 00:00:00 2001 From: Piotr Roslaniec Date: Thu, 13 Jul 2023 10:16:09 +0200 Subject: [PATCH] add equality to FerveoVariant python bindings --- ferveo-python/ferveo/__init__.pyi | 19 ++++++++++++++++++ ferveo-python/test/test_serialization.py | 3 +++ ferveo/src/api.rs | 13 +++++++++++- ferveo/src/bindings_python.rs | 10 +++++----- ferveo/src/bindings_wasm.rs | 25 +++++++++++++++++++----- 5 files changed, 59 insertions(+), 11 deletions(-) diff --git a/ferveo-python/ferveo/__init__.pyi b/ferveo-python/ferveo/__init__.pyi index 6c32cdc4..51f982a6 100644 --- a/ferveo-python/ferveo/__init__.pyi +++ b/ferveo-python/ferveo/__init__.pyi @@ -1,5 +1,6 @@ from typing import Sequence, final + @final class Keypair: @staticmethod @@ -24,6 +25,7 @@ class Keypair: def public_key(self) -> FerveoPublicKey: ... + @final class FerveoPublicKey: @staticmethod @@ -40,6 +42,10 @@ class FerveoPublicKey: def serialized_size() -> int: ... + def __eq__(self, other: object) -> bool: + ... + + @final class Validator: @@ -50,6 +56,7 @@ class Validator: public_key: FerveoPublicKey + @final class Transcript: @staticmethod @@ -59,6 +66,7 @@ class Transcript: def __bytes__(self) -> bytes: ... + @final class DkgPublicKey: @staticmethod @@ -72,6 +80,7 @@ class DkgPublicKey: def serialized_size() -> int: ... + @final class ValidatorMessage: @@ -85,6 +94,7 @@ class ValidatorMessage: validator: Validator transcript: Transcript + @final class Dkg: @@ -106,6 +116,7 @@ class Dkg: def aggregate_transcripts(self, messages: Sequence[ValidatorMessage]) -> AggregatedTranscript: ... + @final class Ciphertext: @staticmethod @@ -115,6 +126,7 @@ class Ciphertext: def __bytes__(self) -> bytes: ... + @final class DecryptionShareSimple: @staticmethod @@ -123,6 +135,8 @@ class DecryptionShareSimple: def __bytes__(self) -> bytes: ... + + @final class DecryptionSharePrecomputed: @staticmethod @@ -132,6 +146,7 @@ class DecryptionSharePrecomputed: def __bytes__(self) -> bytes: ... + @final class AggregatedTranscript: @@ -166,6 +181,7 @@ class AggregatedTranscript: def __bytes__(self) -> bytes: ... + @final class SharedSecret: @@ -182,6 +198,9 @@ class FerveoVariant: simple: FerveoVariant precomputed: FerveoVariant + def __eq__(self, other: object) -> bool: + ... + def encrypt(message: bytes, aad: bytes, dkg_public_key: DkgPublicKey) -> Ciphertext: ... diff --git a/ferveo-python/test/test_serialization.py b/ferveo-python/test/test_serialization.py index 9eaccdbd..78e07868 100644 --- a/ferveo-python/test/test_serialization.py +++ b/ferveo-python/test/test_serialization.py @@ -83,3 +83,6 @@ def test_public_key_serialization(): def test_ferveo_variant_serialization(): assert str(FerveoVariant.precomputed) == "FerveoVariant::Precomputed" assert str(FerveoVariant.simple) == "FerveoVariant::Simple" + assert FerveoVariant.precomputed == FerveoVariant.precomputed + assert FerveoVariant.simple == FerveoVariant.simple + assert FerveoVariant.precomputed != FerveoVariant.simple diff --git a/ferveo/src/api.rs b/ferveo/src/api.rs index c909cfed..ccf9ff5f 100644 --- a/ferveo/src/api.rs +++ b/ferveo/src/api.rs @@ -23,6 +23,8 @@ pub type ValidatorMessage = (Validator, Transcript); #[cfg(feature = "bindings-python")] use crate::bindings_python; +#[cfg(feature = "bindings-wasm")] +use crate::bindings_wasm; pub use crate::EthereumAddress; use crate::{ do_verify_aggregation, Error, PVSSMap, PubliclyVerifiableParams, @@ -72,7 +74,9 @@ pub fn decrypt_with_shared_secret( } /// The ferveo variant to use for the decryption share derivation. -#[derive(PartialEq, Eq, Debug, Serialize, Deserialize, Copy, Clone)] +#[derive( + PartialEq, Eq, Debug, Serialize, Deserialize, Copy, Clone, PartialOrd, +)] pub enum FerveoVariant { /// The simple variant requires m of n shares to decrypt Simple, @@ -110,6 +114,13 @@ impl From for FerveoVariant { } } +#[cfg(feature = "bindings-wasm")] +impl From for FerveoVariant { + fn from(variant: bindings_wasm::FerveoVariant) -> Self { + variant.0 + } +} + #[serde_as] #[derive(Copy, Clone, Debug, PartialEq, Serialize, Deserialize)] pub struct DkgPublicKey( diff --git a/ferveo/src/bindings_python.rs b/ferveo/src/bindings_python.rs index 51088e23..c01dffc4 100644 --- a/ferveo/src/bindings_python.rs +++ b/ferveo/src/bindings_python.rs @@ -271,7 +271,9 @@ pub fn decrypt_with_shared_secret( } #[pyclass(module = "ferveo")] -#[derive(Clone)] +#[derive( + Clone, PartialEq, PartialOrd, Eq, derive_more::From, derive_more::AsRef, +)] pub struct FerveoVariant(pub(crate) api::FerveoVariant); #[pymethods] @@ -289,11 +291,9 @@ impl FerveoVariant { fn __str__(&self) -> String { self.0.to_string() } -} -impl From for FerveoVariant { - fn from(variant: api::FerveoVariant) -> Self { - Self(variant) + fn __richcmp__(&self, other: &Self, op: CompareOp) -> PyResult { + richcmp(self, other, op) } } diff --git a/ferveo/src/bindings_wasm.rs b/ferveo/src/bindings_wasm.rs index ab610160..7b9ae484 100644 --- a/ferveo/src/bindings_wasm.rs +++ b/ferveo/src/bindings_wasm.rs @@ -162,18 +162,33 @@ macro_rules! generate_common_methods { } #[wasm_bindgen] -pub struct FerveoVariant {} +#[derive(Clone, Debug, derive_more::AsRef, derive_more::From)] +pub struct FerveoVariant(pub(crate) api::FerveoVariant); + +impl fmt::Display for FerveoVariant { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{:?}", self.0) + } +} + +generate_common_methods!(FerveoVariant); #[wasm_bindgen] impl FerveoVariant { #[wasm_bindgen(js_name = "precomputed", getter)] - pub fn precomputed() -> String { - api::FerveoVariant::Precomputed.as_str().to_string() + pub fn precomputed() -> FerveoVariant { + FerveoVariant(api::FerveoVariant::Precomputed) } #[wasm_bindgen(js_name = "simple", getter)] - pub fn simple() -> String { - api::FerveoVariant::Simple.as_str().to_string() + pub fn simple() -> FerveoVariant { + FerveoVariant(api::FerveoVariant::Simple) + } + + #[allow(clippy::inherent_to_string_shadow_display)] + #[wasm_bindgen(js_name = "toString")] + pub fn to_string(&self) -> String { + self.0.to_string() } }