Skip to content

Commit

Permalink
feat: expose ferveo variant in bindings
Browse files Browse the repository at this point in the history
  • Loading branch information
piotr-roslaniec committed Jul 5, 2023
1 parent 3bc28d7 commit e8d0598
Show file tree
Hide file tree
Showing 7 changed files with 94 additions and 1 deletion.
2 changes: 2 additions & 0 deletions ferveo-python/ferveo/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
DkgPublicKey,
SharedSecret,
ValidatorMessage,
FerveoVariant,
ThresholdEncryptionError,
InvalidShareNumberParameter,
InvalidDkgStateToDeal,
Expand All @@ -32,4 +33,5 @@
ValidatorsNotSorted,
ValidatorPublicKeyMismatch,
SerializationError,
InvalidVariant,
)
12 changes: 12 additions & 0 deletions ferveo-python/ferveo/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,14 @@ class SharedSecret:
...


class FerveoVariant:
@staticmethod
def simple() -> str: ...

@staticmethod
def precomputed() -> str: ...


def encrypt(message: bytes, add: bytes, dkg_public_key: DkgPublicKey) -> Ciphertext:
...

Expand Down Expand Up @@ -263,3 +271,7 @@ class ValidatorPublicKeyMismatch(Exception):

class SerializationError(Exception):
pass


class InvalidVariant(Exception):
pass
6 changes: 6 additions & 0 deletions ferveo-python/test/test_serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
Dkg,
DkgPublicKey,
FerveoPublicKey,
FerveoVariant,
SharedSecret,
)

Expand Down Expand Up @@ -77,3 +78,8 @@ def test_public_key_serialization():
deserialized = FerveoPublicKey.from_bytes(serialized)
assert pk == deserialized
assert len(serialized) == FerveoPublicKey.serialized_size()


def test_ferveo_variant_serialization():
assert FerveoVariant.precomputed() == "FerveoVariant::Precomputed"
assert FerveoVariant.simple() == "FerveoVariant::Simple"
34 changes: 33 additions & 1 deletion ferveo/src/api.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use std::io;
use std::{fmt, io};

use ark_poly::{EvaluationDomain, Radix2EvaluationDomain};
use ark_serialize::{CanonicalDeserialize, CanonicalSerialize};
Expand Down Expand Up @@ -69,6 +69,38 @@ pub fn decrypt_with_shared_secret(
.map_err(Error::from)
}

/// The ferveo variant to use for the decryption share derivation.
#[derive(PartialEq, Eq, Debug, Serialize, Deserialize, Copy, Clone)]
pub enum FerveoVariant {
/// The simple variant requires m of n shares to decrypt
Simple,
/// The precomputed variant requires n of n shares to decrypt
Precomputed,
}

impl fmt::Display for FerveoVariant {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "{}", self.as_str())
}
}

impl FerveoVariant {
pub fn as_str(&self) -> &'static str {
match self {
FerveoVariant::Simple => "FerveoVariant::Simple",
FerveoVariant::Precomputed => "FerveoVariant::Precomputed",
}
}

pub fn from_string(s: &str) -> Result<Self> {
match s {
"FerveoVariant::Simple" => Ok(FerveoVariant::Simple),
"FerveoVariant::Precomputed" => Ok(FerveoVariant::Precomputed),
_ => Err(Error::InvalidVariant(s.to_string())),
}
}
}

#[serde_as]
#[derive(Copy, Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct DkgPublicKey(
Expand Down
22 changes: 22 additions & 0 deletions ferveo/src/bindings_python.rs
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,9 @@ impl From<FerveoPythonError> for PyErr {
expected, actual
))
}
Error::InvalidVariant(variant) => {
InvalidVariant::new_err(variant.to_string())
}
},
_ => default(),
}
Expand Down Expand Up @@ -128,6 +131,7 @@ create_exception!(exceptions, ValidatorsNotSorted, PyValueError);
create_exception!(exceptions, ValidatorPublicKeyMismatch, PyValueError);
create_exception!(exceptions, SerializationError, PyValueError);
create_exception!(exceptions, InvalidByteLength, PyValueError);
create_exception!(exceptions, InvalidVariant, PyValueError);

fn from_py_bytes<T: FromBytes>(bytes: &[u8]) -> PyResult<T> {
T::from_bytes(bytes)
Expand Down Expand Up @@ -267,6 +271,22 @@ pub fn decrypt_with_shared_secret(
.map_err(|err| FerveoPythonError::FerveoError(err).into())
}

#[pyclass(module = "ferveo")]
struct FerveoVariant {}

#[pymethods]
impl FerveoVariant {
#[staticmethod]
fn precomputed() -> &'static str {
api::FerveoVariant::Precomputed.as_str()
}

#[staticmethod]
fn simple() -> &'static str {
api::FerveoVariant::Simple.as_str()
}
}

#[pyclass(module = "ferveo")]
#[derive(derive_more::AsRef)]
pub struct SharedSecret(api::SharedSecret);
Expand Down Expand Up @@ -590,6 +610,7 @@ pub fn make_ferveo_py_module(py: Python<'_>, m: &PyModule) -> PyResult<()> {
m.add_class::<DkgPublicKey>()?;
m.add_class::<SharedSecret>()?;
m.add_class::<ValidatorMessage>()?;
m.add_class::<FerveoVariant>()?;

// Exceptions
m.add(
Expand Down Expand Up @@ -645,6 +666,7 @@ pub fn make_ferveo_py_module(py: Python<'_>, m: &PyModule) -> PyResult<()> {
py.get_type::<ValidatorPublicKeyMismatch>(),
)?;
m.add("SerializationError", py.get_type::<SerializationError>())?;
m.add("InvalidVariant", py.get_type::<InvalidVariant>())?;

Ok(())
}
Expand Down
16 changes: 16 additions & 0 deletions ferveo/src/bindings_wasm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,22 @@ macro_rules! generate_common_methods {
};
}

#[wasm_bindgen]
pub struct FerveoVariant {}

#[wasm_bindgen]
impl FerveoVariant {
#[wasm_bindgen(js_name = "precomputed", getter)]
pub fn precomputed() -> String {
api::FerveoVariant::Precomputed.as_str().to_string()
}

#[wasm_bindgen(js_name = "simple", getter)]
pub fn simple() -> String {
api::FerveoVariant::Simple.as_str().to_string()
}
}

#[derive(TryFromJsValue)]
#[wasm_bindgen]
#[derive(Clone, Debug, derive_more::AsRef, derive_more::From)]
Expand Down
3 changes: 3 additions & 0 deletions ferveo/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,9 @@ pub enum Error {

#[error("Invalid byte length. Expected {0}, got {1}")]
InvalidByteLength(usize, usize),

#[error("Invalid variant: {0}")]
InvalidVariant(String),
}

pub type Result<T> = std::result::Result<T, Error>;
Expand Down

0 comments on commit e8d0598

Please sign in to comment.