Skip to content

Commit

Permalink
feat: expose ferveo variant in bindings
Browse files Browse the repository at this point in the history
  • Loading branch information
piotr-roslaniec committed Jul 3, 2023
1 parent 4939e79 commit 60cad20
Show file tree
Hide file tree
Showing 7 changed files with 135 additions and 3 deletions.
2 changes: 2 additions & 0 deletions ferveo-python/ferveo/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
DkgPublicKey,
SharedSecret,
ValidatorMessage,
FerveoVariant,
ThresholdEncryptionError,
InvalidShareNumberParameter,
InvalidDkgStateToDeal,
Expand All @@ -32,4 +33,5 @@
ValidatorsNotSorted,
ValidatorPublicKeyMismatch,
SerializationError,
InvalidVariant,
)
14 changes: 14 additions & 0 deletions ferveo-python/ferveo/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,16 @@ class SharedSecret:
...


# TODO: Figure out how to make this a proper enum or expose a plain class
class FerveoVariant:

def from_string(self, variant: str) -> FerveoVariant:
...

def __str__(self) -> str:
...


def encrypt(message: bytes, add: bytes, dkg_public_key: DkgPublicKey) -> Ciphertext:
...

Expand Down Expand Up @@ -260,3 +270,7 @@ class ValidatorPublicKeyMismatch(Exception):

class SerializationError(Exception):
pass


class InvalidVariant(Exception):
pass
7 changes: 6 additions & 1 deletion ferveo-python/test/test_serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
Dkg,
DkgPublicKey,
FerveoPublicKey,
FerveoVariant,
)


Expand Down Expand Up @@ -64,7 +65,11 @@ def test_dkg_public_key_serialization():
assert len(serialized) == DkgPublicKey.serialized_size()


def test_dkg_public_key_serialization():
def test_public_key_serialization():
pk = make_pk()
serialized = bytes(pk)
assert len(serialized) == FerveoPublicKey.serialized_size()

# TODO: Consider different API, FerveoVariant.Precomputed etc.
def test_ferveo_variant_serialization():
variant = FerveoVariant('FerveoVariant::Precomputed')
32 changes: 31 additions & 1 deletion ferveo/src/api.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use std::io;
use std::{fmt, io};

use ark_poly::{EvaluationDomain, Radix2EvaluationDomain};
use ark_serialize::{CanonicalDeserialize, CanonicalSerialize};
Expand Down Expand Up @@ -69,6 +69,36 @@ pub fn decrypt_with_shared_secret(
.map_err(Error::from)
}

#[serde_as]
#[derive(Copy, Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub enum FerveoVariant {
Simple = 0_isize,
Precomputed = 1_isize,
}

impl fmt::Display for FerveoVariant {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "{}", self.as_str())
}
}

impl FerveoVariant {
pub fn as_str(&self) -> &'static str {
match self {
FerveoVariant::Simple => "FerveoVariant::Simple",
FerveoVariant::Precomputed => "FerveoVariant::Precomputed",
}
}

pub fn from_string(s: &str) -> Result<Self> {
match s {
"FerveoVariant::Simple" => Ok(FerveoVariant::Simple),
"FerveoVariant::Precomputed" => Ok(FerveoVariant::Precomputed),
_ => Err(Error::InvalidVariant(s.to_string())),
}
}
}

#[serde_as]
#[derive(Copy, Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct DkgPublicKey(
Expand Down
43 changes: 42 additions & 1 deletion ferveo/src/bindings_python.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use pyo3::{
create_exception,
exceptions::{PyException, PyRuntimeError, PyValueError},
prelude::*,
types::{PyBytes, PyUnicode},
types::{PyBytes, PyString, PyUnicode},
PyClass,
};
use rand::thread_rng;
Expand Down Expand Up @@ -94,6 +94,9 @@ impl From<FerveoPythonError> for PyErr {
expected, actual
))
}
Error::InvalidVariant(variant) => {
InvalidVariant::new_err(variant.to_string())
}
},
_ => default(),
}
Expand Down Expand Up @@ -128,6 +131,7 @@ create_exception!(exceptions, ValidatorsNotSorted, PyValueError);
create_exception!(exceptions, ValidatorPublicKeyMismatch, PyValueError);
create_exception!(exceptions, SerializationError, PyValueError);
create_exception!(exceptions, InvalidByteLength, PyValueError);
create_exception!(exceptions, InvalidVariant, PyValueError);

fn from_py_bytes<T: FromBytes>(bytes: &[u8]) -> PyResult<T> {
T::from_bytes(bytes)
Expand Down Expand Up @@ -278,6 +282,41 @@ pub fn decrypt_with_shared_secret(
.map_err(|err| FerveoPythonError::FerveoError(err).into())
}

#[pyclass(module = "ferveo")]
pub enum FerveoVariant {
Simple = api::FerveoVariant::Simple as isize,
Precomputed = api::FerveoVariant::Precomputed as isize,
}

impl From<api::FerveoVariant> for FerveoVariant {
fn from(variant: api::FerveoVariant) -> Self {
match variant {
api::FerveoVariant::Simple => FerveoVariant::Simple,
api::FerveoVariant::Precomputed => FerveoVariant::Precomputed,
}
}
}

#[pymethods]
impl FerveoVariant {
#[new]
pub fn new(s: &PyString) -> PyResult<Self> {
api::FerveoVariant::from_string(s.to_string().as_str())
.map_err(|err| FerveoPythonError::FerveoError(err).into())
.map(|variant| variant.into())
}

#[getter]
pub fn __str__(&self) -> String {
match self {
FerveoVariant::Simple => api::FerveoVariant::Simple.to_string(),
FerveoVariant::Precomputed => {
api::FerveoVariant::Precomputed.to_string()
}
}
}
}

#[pyclass(module = "ferveo")]
#[derive(derive_more::AsRef)]
pub struct SharedSecret(api::SharedSecret);
Expand Down Expand Up @@ -600,6 +639,7 @@ pub fn make_ferveo_py_module(py: Python<'_>, m: &PyModule) -> PyResult<()> {
m.add_class::<DkgPublicKey>()?;
m.add_class::<SharedSecret>()?;
m.add_class::<ValidatorMessage>()?;
m.add_class::<FerveoVariant>()?;

// Exceptions
m.add(
Expand Down Expand Up @@ -655,6 +695,7 @@ pub fn make_ferveo_py_module(py: Python<'_>, m: &PyModule) -> PyResult<()> {
py.get_type::<ValidatorPublicKeyMismatch>(),
)?;
m.add("SerializationError", py.get_type::<SerializationError>())?;
m.add("InvalidVariant", py.get_type::<InvalidVariant>())?;

Ok(())
}
Expand Down
37 changes: 37 additions & 0 deletions ferveo/src/bindings_wasm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,43 @@ macro_rules! generate_common_methods {
};
}

#[derive(TryFromJsValue)]
#[wasm_bindgen]
#[derive(Clone, Debug, derive_more::AsRef, derive_more::From)]
pub struct FerveoVariant {
variant: api::FerveoVariant,
}

#[wasm_bindgen]
impl FerveoVariant {
#[wasm_bindgen(js_name = "fromString")]
pub fn from_string(s: &str) -> JsResult<FerveoVariant> {
let variant = api::FerveoVariant::from_string(s).map_err(map_js_err)?;
Ok(Self { variant })
}

// Allow `to_string` here because we want to expose it to bindings
#[allow(clippy::inherent_to_string)]
#[wasm_bindgen(js_name = "toString")]
pub fn to_string(&self) -> String {
self.variant.to_string()
}

#[wasm_bindgen(js_name = "Precomputed", getter)]
pub fn precomputed() -> FerveoVariant {
Self {
variant: api::FerveoVariant::Precomputed,
}
}

#[wasm_bindgen(js_name = "Simple", getter)]
pub fn simple() -> FerveoVariant {
Self {
variant: api::FerveoVariant::Simple,
}
}
}

#[derive(TryFromJsValue)]
#[wasm_bindgen]
#[derive(Clone, Debug, derive_more::AsRef, derive_more::From)]
Expand Down
3 changes: 3 additions & 0 deletions ferveo/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,9 @@ pub enum Error {

#[error("Invalid byte length. Expected {0}, got {1}")]
InvalidByteLength(usize, usize),

#[error("Invalid variant: {0}")]
InvalidVariant(String),
}

pub type Result<T> = std::result::Result<T, Error>;
Expand Down

0 comments on commit 60cad20

Please sign in to comment.