Skip to content

Commit

Permalink
refactor(dkg): refactor dkg params into a seperate struct
Browse files Browse the repository at this point in the history
  • Loading branch information
piotr-roslaniec committed Jan 18, 2024
1 parent 87c5f34 commit 1266853
Show file tree
Hide file tree
Showing 11 changed files with 220 additions and 70 deletions.
74 changes: 74 additions & 0 deletions ferveo-python/test/test_ferveo.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,80 @@ def test_precomputed_tdec_doesnt_have_enough_messages():
)


def test_dkg_has_min_shares():
validators_num = 7
shares_num = 3
threshold = 3

tau = 1
validator_keypairs = [Keypair.random() for _ in range(0, validators_num)]
validators = [
Validator(gen_eth_addr(i), keypair.public_key())
for i, keypair in enumerate(validator_keypairs)
]
validators.sort(key=lambda v: v.address)

messages = []
for sender in validators:
dkg = Dkg(
tau=tau,
shares_num=shares_num,
security_threshold=threshold,
validators=validators,
me=sender,
)
messages.append(ValidatorMessage(sender, dkg.generate_transcript()))

dkg = Dkg(
tau=tau,
shares_num=shares_num,
security_threshold=threshold,
validators=validators,
me=validators[0],
)
pvss_aggregated = dkg.aggregate_transcripts(messages)
assert pvss_aggregated.verify(shares_num, messages)

dkg_pk_bytes = bytes(dkg.public_key)
dkg_pk = DkgPublicKey.from_bytes(dkg_pk_bytes)

msg = "abc".encode()
aad = "my-aad".encode()
ciphertext = encrypt(msg, aad, dkg_pk)

decryption_shares = []
for validator, validator_keypair in zip(validators, validator_keypairs):
dkg = Dkg(
tau=tau,
shares_num=validators_num,
security_threshold=threshold,
validators=validators,
me=validator,
)
pvss_aggregated = dkg.aggregate_transcripts(messages)
assert pvss_aggregated.verify(validators_num, messages)

decryption_share = decryption_share_for_variant(variant, pvss_aggregated)(
dkg, ciphertext.header, aad, validator_keypair
)
decryption_shares.append(decryption_share)

shared_secret = combine_shares_for_variant(variant, decryption_shares)

if variant == FerveoVariant.Simple and len(decryption_shares) < threshold:
with pytest.raises(ThresholdEncryptionError):
decrypt_with_shared_secret(ciphertext, aad, shared_secret)
return

if variant == FerveoVariant.Precomputed and len(decryption_shares) < validators_num:
with pytest.raises(ThresholdEncryptionError):
decrypt_with_shared_secret(ciphertext, aad, shared_secret)
return

plaintext = decrypt_with_shared_secret(ciphertext, aad, shared_secret)
assert bytes(plaintext) == msg


PARAMS = [
(1, FerveoVariant.Simple),
(3, FerveoVariant.Simple),
Expand Down
27 changes: 14 additions & 13 deletions ferveo-wasm/tests/node.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@ use wasm_bindgen_test::*;

type TestSetup = (
u32,
usize,
usize,
u32,
u32,
Vec<Keypair>,
Vec<Validator>,
ValidatorArray,
Expand All @@ -21,11 +21,12 @@ type TestSetup = (

fn setup_dkg() -> TestSetup {
let tau = 1;
let shares_num = 16;
let shares_num: u32 = 16;
let security_threshold = shares_num * 2 / 3;

let validator_keypairs =
(0..shares_num).map(gen_keypair).collect::<Vec<Keypair>>();
let validator_keypairs = (0..shares_num as usize)
.map(gen_keypair)
.collect::<Vec<Keypair>>();
let validators = validator_keypairs
.iter()
.enumerate()
Expand All @@ -38,8 +39,8 @@ fn setup_dkg() -> TestSetup {
let messages = validators.iter().map(|sender| {
let dkg = Dkg::new(
tau,
shares_num as u32,
security_threshold as u32,
shares_num,
security_threshold,
&validators_js,
sender,
)
Expand All @@ -54,8 +55,8 @@ fn setup_dkg() -> TestSetup {

let mut dkg = Dkg::new(
tau,
shares_num as u32,
security_threshold as u32,
shares_num,
security_threshold,
&validators_js,
&validators[0],
)
Expand Down Expand Up @@ -112,8 +113,8 @@ fn tdec_simple() {
.map(|(validator, keypair)| {
let mut dkg = Dkg::new(
tau,
shares_num as u32,
security_threshold as u32,
shares_num,
security_threshold,
&validators_js,
&validator,
)
Expand Down Expand Up @@ -166,8 +167,8 @@ fn tdec_precomputed() {
.map(|(validator, keypair)| {
let mut dkg = Dkg::new(
tau,
shares_num as u32,
security_threshold as u32,
shares_num,
security_threshold,
&validators_js,
&validator,
)
Expand Down
6 changes: 1 addition & 5 deletions ferveo/benches/benchmarks/validity_checks.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,11 +45,7 @@ fn setup_dkg(
let me = validators[validator].clone();
PubliclyVerifiableDkg::new(
&validators,
&DkgParams {
tau: 0,
security_threshold: shares_num / 3,
shares_num,
},
&DkgParams::new(0, shares_num / 3, shares_num).unwrap(),
&me,
)
.expect("Setup failed")
Expand Down
6 changes: 1 addition & 5 deletions ferveo/examples/bench_primitives_size.rs
Original file line number Diff line number Diff line change
Expand Up @@ -80,11 +80,7 @@ fn setup_dkg(
let me = validators[validator].clone();
PubliclyVerifiableDkg::new(
&validators,
&DkgParams {
tau: 0,
security_threshold,
shares_num,
},
&DkgParams::new(0, security_threshold, shares_num).unwrap(),
&me,
)
.expect("Setup failed")
Expand Down
15 changes: 8 additions & 7 deletions ferveo/src/api.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,9 @@ pub type PublicKey = ferveo_common::PublicKey<E>;
pub type Keypair = ferveo_common::Keypair<E>;
pub type Validator = crate::Validator<E>;
pub type Transcript = PubliclyVerifiableSS<E>;

// pub type ShareIndex = u32;
// pub type ValidatorMessage = (ShareIndex, Validator, Transcript);
pub type ValidatorMessage = (Validator, Transcript);

#[cfg(feature = "bindings-python")]
Expand Down Expand Up @@ -203,11 +206,8 @@ impl Dkg {
validators: &[Validator],
me: &Validator,
) -> Result<Self> {
let dkg_params = crate::DkgParams {
tau,
security_threshold,
shares_num,
};
let dkg_params =
crate::DkgParams::new(tau, security_threshold, shares_num)?;
let dkg = crate::PubliclyVerifiableDkg::<E>::new(
validators,
&dkg_params,
Expand Down Expand Up @@ -312,7 +312,7 @@ impl AggregatedTranscript {
.0
.domain
.elements()
.take(dkg.0.dkg_params.shares_num as usize)
.take(dkg.0.dkg_params.shares_num() as usize)
.collect();
self.0.make_decryption_share_simple_precomputed(
&ciphertext_header.0,
Expand Down Expand Up @@ -434,6 +434,7 @@ mod test_ferveo_api {
(sender.clone(), dkg.generate_transcript(rng).unwrap())
})
.collect();

(messages, validators, validator_keypairs)
}

Expand Down Expand Up @@ -647,7 +648,7 @@ mod test_ferveo_api {

let local_aggregate = dkg.aggregate_transcripts(&messages).unwrap();
assert!(local_aggregate
.verify(dkg.0.dkg_params.shares_num, &messages)
.verify(dkg.0.dkg_params.shares_num(), &messages)
.is_ok());
}

Expand Down
14 changes: 13 additions & 1 deletion ferveo/src/bindings_python.rs
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,17 @@ impl From<FerveoPythonError> for PyErr {
}
Error::InvalidVariant(variant) => {
InvalidVariant::new_err(variant.to_string())
}
},
Error::InvalidDkgParameters(num_shares, security_threshold) => {
InvalidDkgParameters::new_err(format!(
"num_shares: {num_shares}, security_threshold: {security_threshold}"
))
},
Error::InvalidShareIndex(index) => {
InvalidShareIndex::new_err(format!(
"{index}"
))
},
},
_ => default(),
}
Expand Down Expand Up @@ -128,6 +138,8 @@ create_exception!(exceptions, ValidatorPublicKeyMismatch, PyValueError);
create_exception!(exceptions, SerializationError, PyValueError);
create_exception!(exceptions, InvalidByteLength, PyValueError);
create_exception!(exceptions, InvalidVariant, PyValueError);
create_exception!(exceptions, InvalidDkgParameters, PyValueError);
create_exception!(exceptions, InvalidShareIndex, PyValueError);

fn from_py_bytes<T: FromBytes>(bytes: &[u8]) -> PyResult<T> {
T::from_bytes(bytes)
Expand Down
8 changes: 3 additions & 5 deletions ferveo/src/bindings_wasm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -510,15 +510,13 @@ impl AggregatedTranscript {
#[wasm_bindgen]
pub fn verify(
&self,
shares_num: usize,
shares_num: u32,
messages: &ValidatorMessageArray,
) -> JsResult<bool> {
set_panic_hook();
let messages = unwrap_messages_js(messages)?;
let is_valid = self
.0
.verify(shares_num as u32, &messages)
.map_err(map_js_err)?;
let is_valid =
self.0.verify(shares_num, &messages).map_err(map_js_err)?;
Ok(is_valid)
}

Expand Down
71 changes: 68 additions & 3 deletions ferveo/src/dkg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,50 @@ use crate::{

#[derive(Copy, Clone, Debug, Serialize, Deserialize)]
pub struct DkgParams {
pub tau: u32,
pub security_threshold: u32,
pub shares_num: u32,
tau: u32,
security_threshold: u32,
shares_num: u32,
}

impl DkgParams {
/// Create new DKG parameters
/// `tau` is a unique identifier for the DKG (ritual id)
/// `security_threshold` is the minimum number of shares required to reconstruct the key
/// `shares_num` is the total number of shares to be generated
/// Returns an error if the parameters are invalid
/// Parameters must hold: `shares_num` >= `security_threshold`
pub fn new(
tau: u32,
security_threshold: u32,
shares_num: u32,
) -> Result<Self> {
if shares_num < security_threshold
|| shares_num == 0
|| security_threshold == 0
{
return Err(Error::InvalidDkgParameters(
shares_num,
security_threshold,
));
}
Ok(Self {
tau,
security_threshold,
shares_num,
})
}

pub fn tau(&self) -> u32 {
self.tau
}

pub fn security_threshold(&self) -> u32 {
self.security_threshold
}

pub fn shares_num(&self) -> u32 {
self.shares_num
}
}

#[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)]
Expand Down Expand Up @@ -123,6 +164,7 @@ impl<E: Pairing> PubliclyVerifiableDkg<E> {
validators,
state: DkgState::Sharing {
accumulated_shares: 0,
// TODO: Do we need to keep track of the block number?
block: 0,
},
})
Expand Down Expand Up @@ -151,6 +193,7 @@ impl<E: Pairing> PubliclyVerifiableDkg<E> {
}
}

// TODO: Make private, use `share` instead. Currently used only in bindings
pub fn create_share<R: RngCore>(
&self,
rng: &mut R,
Expand Down Expand Up @@ -248,6 +291,10 @@ impl<E: Pairing> PubliclyVerifiableDkg<E> {
return Err(Error::UnknownDealer(sender.clone().address));
}

// TODO: Throw error instead of silently accepting excess shares?
// if self.vss.len() < self.dkg_params.shares_num as usize {
// self.vss.insert(sender.address.clone(), pvss.clone());
// }
self.vss.insert(sender.address.clone(), pvss.clone());

// we keep track of the amount of shares seen until the security
Expand Down Expand Up @@ -751,3 +798,21 @@ mod test_aggregation {
assert!(dkg.verify_message(&sender, &aggregate).is_err());
}
}

/// Test DKG parameters
#[cfg(test)]
mod test_dkg_params {
const TAU: u32 = 0;

#[test]
fn test_shares_num_less_than_security_threshold() {
let dkg_params = super::DkgParams::new(TAU, 4, 3);
assert!(dkg_params.is_err());
}

#[test]
fn test_valid_dkg_params() {
let dkg_params = super::DkgParams::new(TAU, 2, 3);
assert!(dkg_params.is_ok());
}
}
Loading

0 comments on commit 1266853

Please sign in to comment.