From cea467e0bd48a096f70dd1c7ca24a7e4bd88b3d4 Mon Sep 17 00:00:00 2001
From: Piotr Roslaniec
Date: Thu, 13 Jul 2023 10:16:09 +0200
Subject: [PATCH] add equality to FerveoVariant python bindings
---
ferveo-python/ferveo/__init__.pyi | 19 ++++++++++++++++++
ferveo-python/test/test_serialization.py | 3 +++
ferveo/src/api.rs | 13 +++++++++++-
ferveo/src/bindings_python.rs | 10 +++++-----
ferveo/src/bindings_wasm.rs | 25 +++++++++++++++++++-----
5 files changed, 59 insertions(+), 11 deletions(-)
diff --git a/ferveo-python/ferveo/__init__.pyi b/ferveo-python/ferveo/__init__.pyi
index 6c32cdc4..51f982a6 100644
--- a/ferveo-python/ferveo/__init__.pyi
+++ b/ferveo-python/ferveo/__init__.pyi
@@ -1,5 +1,6 @@
from typing import Sequence, final
+
@final
class Keypair:
@staticmethod
@@ -24,6 +25,7 @@ class Keypair:
def public_key(self) -> FerveoPublicKey:
...
+
@final
class FerveoPublicKey:
@staticmethod
@@ -40,6 +42,10 @@ class FerveoPublicKey:
def serialized_size() -> int:
...
+ def __eq__(self, other: object) -> bool:
+ ...
+
+
@final
class Validator:
@@ -50,6 +56,7 @@ class Validator:
public_key: FerveoPublicKey
+
@final
class Transcript:
@staticmethod
@@ -59,6 +66,7 @@ class Transcript:
def __bytes__(self) -> bytes:
...
+
@final
class DkgPublicKey:
@staticmethod
@@ -72,6 +80,7 @@ class DkgPublicKey:
def serialized_size() -> int:
...
+
@final
class ValidatorMessage:
@@ -85,6 +94,7 @@ class ValidatorMessage:
validator: Validator
transcript: Transcript
+
@final
class Dkg:
@@ -106,6 +116,7 @@ class Dkg:
def aggregate_transcripts(self, messages: Sequence[ValidatorMessage]) -> AggregatedTranscript:
...
+
@final
class Ciphertext:
@staticmethod
@@ -115,6 +126,7 @@ class Ciphertext:
def __bytes__(self) -> bytes:
...
+
@final
class DecryptionShareSimple:
@staticmethod
@@ -123,6 +135,8 @@ class DecryptionShareSimple:
def __bytes__(self) -> bytes:
...
+
+
@final
class DecryptionSharePrecomputed:
@staticmethod
@@ -132,6 +146,7 @@ class DecryptionSharePrecomputed:
def __bytes__(self) -> bytes:
...
+
@final
class AggregatedTranscript:
@@ -166,6 +181,7 @@ class AggregatedTranscript:
def __bytes__(self) -> bytes:
...
+
@final
class SharedSecret:
@@ -182,6 +198,9 @@ class FerveoVariant:
simple: FerveoVariant
precomputed: FerveoVariant
+ def __eq__(self, other: object) -> bool:
+ ...
+
def encrypt(message: bytes, aad: bytes, dkg_public_key: DkgPublicKey) -> Ciphertext:
...
diff --git a/ferveo-python/test/test_serialization.py b/ferveo-python/test/test_serialization.py
index 9eaccdbd..78e07868 100644
--- a/ferveo-python/test/test_serialization.py
+++ b/ferveo-python/test/test_serialization.py
@@ -83,3 +83,6 @@ def test_public_key_serialization():
def test_ferveo_variant_serialization():
assert str(FerveoVariant.precomputed) == "FerveoVariant::Precomputed"
assert str(FerveoVariant.simple) == "FerveoVariant::Simple"
+ assert FerveoVariant.precomputed == FerveoVariant.precomputed
+ assert FerveoVariant.simple == FerveoVariant.simple
+ assert FerveoVariant.precomputed != FerveoVariant.simple
diff --git a/ferveo/src/api.rs b/ferveo/src/api.rs
index c909cfed..ccf9ff5f 100644
--- a/ferveo/src/api.rs
+++ b/ferveo/src/api.rs
@@ -23,6 +23,8 @@ pub type ValidatorMessage = (Validator, Transcript);
#[cfg(feature = "bindings-python")]
use crate::bindings_python;
+#[cfg(feature = "bindings-wasm")]
+use crate::bindings_wasm;
pub use crate::EthereumAddress;
use crate::{
do_verify_aggregation, Error, PVSSMap, PubliclyVerifiableParams,
@@ -72,7 +74,9 @@ pub fn decrypt_with_shared_secret(
}
/// The ferveo variant to use for the decryption share derivation.
-#[derive(PartialEq, Eq, Debug, Serialize, Deserialize, Copy, Clone)]
+#[derive(
+ PartialEq, Eq, Debug, Serialize, Deserialize, Copy, Clone, PartialOrd,
+)]
pub enum FerveoVariant {
/// The simple variant requires m of n shares to decrypt
Simple,
@@ -110,6 +114,13 @@ impl From for FerveoVariant {
}
}
+#[cfg(feature = "bindings-wasm")]
+impl From for FerveoVariant {
+ fn from(variant: bindings_wasm::FerveoVariant) -> Self {
+ variant.0
+ }
+}
+
#[serde_as]
#[derive(Copy, Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct DkgPublicKey(
diff --git a/ferveo/src/bindings_python.rs b/ferveo/src/bindings_python.rs
index 51088e23..c01dffc4 100644
--- a/ferveo/src/bindings_python.rs
+++ b/ferveo/src/bindings_python.rs
@@ -271,7 +271,9 @@ pub fn decrypt_with_shared_secret(
}
#[pyclass(module = "ferveo")]
-#[derive(Clone)]
+#[derive(
+ Clone, PartialEq, PartialOrd, Eq, derive_more::From, derive_more::AsRef,
+)]
pub struct FerveoVariant(pub(crate) api::FerveoVariant);
#[pymethods]
@@ -289,11 +291,9 @@ impl FerveoVariant {
fn __str__(&self) -> String {
self.0.to_string()
}
-}
-impl From for FerveoVariant {
- fn from(variant: api::FerveoVariant) -> Self {
- Self(variant)
+ fn __richcmp__(&self, other: &Self, op: CompareOp) -> PyResult {
+ richcmp(self, other, op)
}
}
diff --git a/ferveo/src/bindings_wasm.rs b/ferveo/src/bindings_wasm.rs
index ab610160..7b9ae484 100644
--- a/ferveo/src/bindings_wasm.rs
+++ b/ferveo/src/bindings_wasm.rs
@@ -162,18 +162,33 @@ macro_rules! generate_common_methods {
}
#[wasm_bindgen]
-pub struct FerveoVariant {}
+#[derive(Clone, Debug, derive_more::AsRef, derive_more::From)]
+pub struct FerveoVariant(pub(crate) api::FerveoVariant);
+
+impl fmt::Display for FerveoVariant {
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+ write!(f, "{:?}", self.0)
+ }
+}
+
+generate_common_methods!(FerveoVariant);
#[wasm_bindgen]
impl FerveoVariant {
#[wasm_bindgen(js_name = "precomputed", getter)]
- pub fn precomputed() -> String {
- api::FerveoVariant::Precomputed.as_str().to_string()
+ pub fn precomputed() -> FerveoVariant {
+ FerveoVariant(api::FerveoVariant::Precomputed)
}
#[wasm_bindgen(js_name = "simple", getter)]
- pub fn simple() -> String {
- api::FerveoVariant::Simple.as_str().to_string()
+ pub fn simple() -> FerveoVariant {
+ FerveoVariant(api::FerveoVariant::Simple)
+ }
+
+ #[allow(clippy::inherent_to_string_shadow_display)]
+ #[wasm_bindgen(js_name = "toString")]
+ pub fn to_string(&self) -> String {
+ self.0.to_string()
}
}