From c3b749732cc2d560240f7a1043c18d10138e30e9 Mon Sep 17 00:00:00 2001
From: Piotr Roslaniec
Date: Thu, 13 Jul 2023 10:16:09 +0200
Subject: [PATCH] add equality to FerveoVariant python bindings
---
ferveo-python/ferveo/__init__.pyi | 16 ++++++++++++++++
ferveo-python/test/test_serialization.py | 3 +++
ferveo/src/bindings_python.rs | 4 ++++
3 files changed, 23 insertions(+)
diff --git a/ferveo-python/ferveo/__init__.pyi b/ferveo-python/ferveo/__init__.pyi
index 6c32cdc4..7e45d878 100644
--- a/ferveo-python/ferveo/__init__.pyi
+++ b/ferveo-python/ferveo/__init__.pyi
@@ -1,5 +1,6 @@
from typing import Sequence, final
+
@final
class Keypair:
@staticmethod
@@ -24,6 +25,7 @@ class Keypair:
def public_key(self) -> FerveoPublicKey:
...
+
@final
class FerveoPublicKey:
@staticmethod
@@ -40,6 +42,7 @@ class FerveoPublicKey:
def serialized_size() -> int:
...
+
@final
class Validator:
@@ -50,6 +53,7 @@ class Validator:
public_key: FerveoPublicKey
+
@final
class Transcript:
@staticmethod
@@ -59,6 +63,7 @@ class Transcript:
def __bytes__(self) -> bytes:
...
+
@final
class DkgPublicKey:
@staticmethod
@@ -72,6 +77,7 @@ class DkgPublicKey:
def serialized_size() -> int:
...
+
@final
class ValidatorMessage:
@@ -85,6 +91,7 @@ class ValidatorMessage:
validator: Validator
transcript: Transcript
+
@final
class Dkg:
@@ -106,6 +113,7 @@ class Dkg:
def aggregate_transcripts(self, messages: Sequence[ValidatorMessage]) -> AggregatedTranscript:
...
+
@final
class Ciphertext:
@staticmethod
@@ -115,6 +123,7 @@ class Ciphertext:
def __bytes__(self) -> bytes:
...
+
@final
class DecryptionShareSimple:
@staticmethod
@@ -123,6 +132,8 @@ class DecryptionShareSimple:
def __bytes__(self) -> bytes:
...
+
+
@final
class DecryptionSharePrecomputed:
@staticmethod
@@ -132,6 +143,7 @@ class DecryptionSharePrecomputed:
def __bytes__(self) -> bytes:
...
+
@final
class AggregatedTranscript:
@@ -166,6 +178,7 @@ class AggregatedTranscript:
def __bytes__(self) -> bytes:
...
+
@final
class SharedSecret:
@@ -182,6 +195,9 @@ class FerveoVariant:
simple: FerveoVariant
precomputed: FerveoVariant
+ def __eq__(self, other: object) -> bool:
+ ...
+
def encrypt(message: bytes, aad: bytes, dkg_public_key: DkgPublicKey) -> Ciphertext:
...
diff --git a/ferveo-python/test/test_serialization.py b/ferveo-python/test/test_serialization.py
index 9eaccdbd..78e07868 100644
--- a/ferveo-python/test/test_serialization.py
+++ b/ferveo-python/test/test_serialization.py
@@ -83,3 +83,6 @@ def test_public_key_serialization():
def test_ferveo_variant_serialization():
assert str(FerveoVariant.precomputed) == "FerveoVariant::Precomputed"
assert str(FerveoVariant.simple) == "FerveoVariant::Simple"
+ assert FerveoVariant.precomputed == FerveoVariant.precomputed
+ assert FerveoVariant.simple == FerveoVariant.simple
+ assert FerveoVariant.precomputed != FerveoVariant.simple
diff --git a/ferveo/src/bindings_python.rs b/ferveo/src/bindings_python.rs
index 51088e23..f0708474 100644
--- a/ferveo/src/bindings_python.rs
+++ b/ferveo/src/bindings_python.rs
@@ -289,6 +289,10 @@ impl FerveoVariant {
fn __str__(&self) -> String {
self.0.to_string()
}
+
+ fn __eq__(&self, other: &Self) -> bool {
+ self.0 == other.0
+ }
}
impl From for FerveoVariant {