From 126685377ca25b9179fd76f4313a397a2771d8af Mon Sep 17 00:00:00 2001 From: Piotr Roslaniec Date: Thu, 18 Jan 2024 15:40:57 +0100 Subject: [PATCH] refactor(dkg): refactor dkg params into a seperate struct --- ferveo-python/test/test_ferveo.py | 74 ++++++++++++++++++++ ferveo-wasm/tests/node.rs | 27 +++---- ferveo/benches/benchmarks/validity_checks.rs | 6 +- ferveo/examples/bench_primitives_size.rs | 6 +- ferveo/src/api.rs | 15 ++-- ferveo/src/bindings_python.rs | 14 +++- ferveo/src/bindings_wasm.rs | 8 +-- ferveo/src/dkg.rs | 71 ++++++++++++++++++- ferveo/src/lib.rs | 34 +++++---- ferveo/src/pvss.rs | 34 ++++----- ferveo/src/validator.rs | 1 + 11 files changed, 220 insertions(+), 70 deletions(-) diff --git a/ferveo-python/test/test_ferveo.py b/ferveo-python/test/test_ferveo.py index 6f00b6df..66f147c1 100644 --- a/ferveo-python/test/test_ferveo.py +++ b/ferveo-python/test/test_ferveo.py @@ -136,6 +136,80 @@ def test_precomputed_tdec_doesnt_have_enough_messages(): ) +def test_dkg_has_min_shares(): + validators_num = 7 + shares_num = 3 + threshold = 3 + + tau = 1 + validator_keypairs = [Keypair.random() for _ in range(0, validators_num)] + validators = [ + Validator(gen_eth_addr(i), keypair.public_key()) + for i, keypair in enumerate(validator_keypairs) + ] + validators.sort(key=lambda v: v.address) + + messages = [] + for sender in validators: + dkg = Dkg( + tau=tau, + shares_num=shares_num, + security_threshold=threshold, + validators=validators, + me=sender, + ) + messages.append(ValidatorMessage(sender, dkg.generate_transcript())) + + dkg = Dkg( + tau=tau, + shares_num=shares_num, + security_threshold=threshold, + validators=validators, + me=validators[0], + ) + pvss_aggregated = dkg.aggregate_transcripts(messages) + assert pvss_aggregated.verify(shares_num, messages) + + dkg_pk_bytes = bytes(dkg.public_key) + dkg_pk = DkgPublicKey.from_bytes(dkg_pk_bytes) + + msg = "abc".encode() + aad = "my-aad".encode() + ciphertext = encrypt(msg, aad, dkg_pk) + + decryption_shares = [] + for validator, validator_keypair in zip(validators, validator_keypairs): + dkg = Dkg( + tau=tau, + shares_num=validators_num, + security_threshold=threshold, + validators=validators, + me=validator, + ) + pvss_aggregated = dkg.aggregate_transcripts(messages) + assert pvss_aggregated.verify(validators_num, messages) + + decryption_share = decryption_share_for_variant(variant, pvss_aggregated)( + dkg, ciphertext.header, aad, validator_keypair + ) + decryption_shares.append(decryption_share) + + shared_secret = combine_shares_for_variant(variant, decryption_shares) + + if variant == FerveoVariant.Simple and len(decryption_shares) < threshold: + with pytest.raises(ThresholdEncryptionError): + decrypt_with_shared_secret(ciphertext, aad, shared_secret) + return + + if variant == FerveoVariant.Precomputed and len(decryption_shares) < validators_num: + with pytest.raises(ThresholdEncryptionError): + decrypt_with_shared_secret(ciphertext, aad, shared_secret) + return + + plaintext = decrypt_with_shared_secret(ciphertext, aad, shared_secret) + assert bytes(plaintext) == msg + + PARAMS = [ (1, FerveoVariant.Simple), (3, FerveoVariant.Simple), diff --git a/ferveo-wasm/tests/node.rs b/ferveo-wasm/tests/node.rs index b4234d07..4ac71429 100644 --- a/ferveo-wasm/tests/node.rs +++ b/ferveo-wasm/tests/node.rs @@ -8,8 +8,8 @@ use wasm_bindgen_test::*; type TestSetup = ( u32, - usize, - usize, + u32, + u32, Vec, Vec, ValidatorArray, @@ -21,11 +21,12 @@ type TestSetup = ( fn setup_dkg() -> TestSetup { let tau = 1; - let shares_num = 16; + let shares_num: u32 = 16; let security_threshold = shares_num * 2 / 3; - let validator_keypairs = - (0..shares_num).map(gen_keypair).collect::>(); + let validator_keypairs = (0..shares_num as usize) + .map(gen_keypair) + .collect::>(); let validators = validator_keypairs .iter() .enumerate() @@ -38,8 +39,8 @@ fn setup_dkg() -> TestSetup { let messages = validators.iter().map(|sender| { let dkg = Dkg::new( tau, - shares_num as u32, - security_threshold as u32, + shares_num, + security_threshold, &validators_js, sender, ) @@ -54,8 +55,8 @@ fn setup_dkg() -> TestSetup { let mut dkg = Dkg::new( tau, - shares_num as u32, - security_threshold as u32, + shares_num, + security_threshold, &validators_js, &validators[0], ) @@ -112,8 +113,8 @@ fn tdec_simple() { .map(|(validator, keypair)| { let mut dkg = Dkg::new( tau, - shares_num as u32, - security_threshold as u32, + shares_num, + security_threshold, &validators_js, &validator, ) @@ -166,8 +167,8 @@ fn tdec_precomputed() { .map(|(validator, keypair)| { let mut dkg = Dkg::new( tau, - shares_num as u32, - security_threshold as u32, + shares_num, + security_threshold, &validators_js, &validator, ) diff --git a/ferveo/benches/benchmarks/validity_checks.rs b/ferveo/benches/benchmarks/validity_checks.rs index a6dd9f48..cc7266f7 100644 --- a/ferveo/benches/benchmarks/validity_checks.rs +++ b/ferveo/benches/benchmarks/validity_checks.rs @@ -45,11 +45,7 @@ fn setup_dkg( let me = validators[validator].clone(); PubliclyVerifiableDkg::new( &validators, - &DkgParams { - tau: 0, - security_threshold: shares_num / 3, - shares_num, - }, + &DkgParams::new(0, shares_num / 3, shares_num).unwrap(), &me, ) .expect("Setup failed") diff --git a/ferveo/examples/bench_primitives_size.rs b/ferveo/examples/bench_primitives_size.rs index 18adf673..79afb8a4 100644 --- a/ferveo/examples/bench_primitives_size.rs +++ b/ferveo/examples/bench_primitives_size.rs @@ -80,11 +80,7 @@ fn setup_dkg( let me = validators[validator].clone(); PubliclyVerifiableDkg::new( &validators, - &DkgParams { - tau: 0, - security_threshold, - shares_num, - }, + &DkgParams::new(0, security_threshold, shares_num).unwrap(), &me, ) .expect("Setup failed") diff --git a/ferveo/src/api.rs b/ferveo/src/api.rs index af3edcd4..be8f5b1f 100644 --- a/ferveo/src/api.rs +++ b/ferveo/src/api.rs @@ -18,6 +18,9 @@ pub type PublicKey = ferveo_common::PublicKey; pub type Keypair = ferveo_common::Keypair; pub type Validator = crate::Validator; pub type Transcript = PubliclyVerifiableSS; + +// pub type ShareIndex = u32; +// pub type ValidatorMessage = (ShareIndex, Validator, Transcript); pub type ValidatorMessage = (Validator, Transcript); #[cfg(feature = "bindings-python")] @@ -203,11 +206,8 @@ impl Dkg { validators: &[Validator], me: &Validator, ) -> Result { - let dkg_params = crate::DkgParams { - tau, - security_threshold, - shares_num, - }; + let dkg_params = + crate::DkgParams::new(tau, security_threshold, shares_num)?; let dkg = crate::PubliclyVerifiableDkg::::new( validators, &dkg_params, @@ -312,7 +312,7 @@ impl AggregatedTranscript { .0 .domain .elements() - .take(dkg.0.dkg_params.shares_num as usize) + .take(dkg.0.dkg_params.shares_num() as usize) .collect(); self.0.make_decryption_share_simple_precomputed( &ciphertext_header.0, @@ -434,6 +434,7 @@ mod test_ferveo_api { (sender.clone(), dkg.generate_transcript(rng).unwrap()) }) .collect(); + (messages, validators, validator_keypairs) } @@ -647,7 +648,7 @@ mod test_ferveo_api { let local_aggregate = dkg.aggregate_transcripts(&messages).unwrap(); assert!(local_aggregate - .verify(dkg.0.dkg_params.shares_num, &messages) + .verify(dkg.0.dkg_params.shares_num(), &messages) .is_ok()); } diff --git a/ferveo/src/bindings_python.rs b/ferveo/src/bindings_python.rs index ed965f3e..00c455a3 100644 --- a/ferveo/src/bindings_python.rs +++ b/ferveo/src/bindings_python.rs @@ -93,7 +93,17 @@ impl From for PyErr { } Error::InvalidVariant(variant) => { InvalidVariant::new_err(variant.to_string()) - } + }, + Error::InvalidDkgParameters(num_shares, security_threshold) => { + InvalidDkgParameters::new_err(format!( + "num_shares: {num_shares}, security_threshold: {security_threshold}" + )) + }, + Error::InvalidShareIndex(index) => { + InvalidShareIndex::new_err(format!( + "{index}" + )) + }, }, _ => default(), } @@ -128,6 +138,8 @@ create_exception!(exceptions, ValidatorPublicKeyMismatch, PyValueError); create_exception!(exceptions, SerializationError, PyValueError); create_exception!(exceptions, InvalidByteLength, PyValueError); create_exception!(exceptions, InvalidVariant, PyValueError); +create_exception!(exceptions, InvalidDkgParameters, PyValueError); +create_exception!(exceptions, InvalidShareIndex, PyValueError); fn from_py_bytes(bytes: &[u8]) -> PyResult { T::from_bytes(bytes) diff --git a/ferveo/src/bindings_wasm.rs b/ferveo/src/bindings_wasm.rs index e412b6e0..5b03f188 100644 --- a/ferveo/src/bindings_wasm.rs +++ b/ferveo/src/bindings_wasm.rs @@ -510,15 +510,13 @@ impl AggregatedTranscript { #[wasm_bindgen] pub fn verify( &self, - shares_num: usize, + shares_num: u32, messages: &ValidatorMessageArray, ) -> JsResult { set_panic_hook(); let messages = unwrap_messages_js(messages)?; - let is_valid = self - .0 - .verify(shares_num as u32, &messages) - .map_err(map_js_err)?; + let is_valid = + self.0.verify(shares_num, &messages).map_err(map_js_err)?; Ok(is_valid) } diff --git a/ferveo/src/dkg.rs b/ferveo/src/dkg.rs index 3f7fc09d..808afd40 100644 --- a/ferveo/src/dkg.rs +++ b/ferveo/src/dkg.rs @@ -15,9 +15,50 @@ use crate::{ #[derive(Copy, Clone, Debug, Serialize, Deserialize)] pub struct DkgParams { - pub tau: u32, - pub security_threshold: u32, - pub shares_num: u32, + tau: u32, + security_threshold: u32, + shares_num: u32, +} + +impl DkgParams { + /// Create new DKG parameters + /// `tau` is a unique identifier for the DKG (ritual id) + /// `security_threshold` is the minimum number of shares required to reconstruct the key + /// `shares_num` is the total number of shares to be generated + /// Returns an error if the parameters are invalid + /// Parameters must hold: `shares_num` >= `security_threshold` + pub fn new( + tau: u32, + security_threshold: u32, + shares_num: u32, + ) -> Result { + if shares_num < security_threshold + || shares_num == 0 + || security_threshold == 0 + { + return Err(Error::InvalidDkgParameters( + shares_num, + security_threshold, + )); + } + Ok(Self { + tau, + security_threshold, + shares_num, + }) + } + + pub fn tau(&self) -> u32 { + self.tau + } + + pub fn security_threshold(&self) -> u32 { + self.security_threshold + } + + pub fn shares_num(&self) -> u32 { + self.shares_num + } } #[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)] @@ -123,6 +164,7 @@ impl PubliclyVerifiableDkg { validators, state: DkgState::Sharing { accumulated_shares: 0, + // TODO: Do we need to keep track of the block number? block: 0, }, }) @@ -151,6 +193,7 @@ impl PubliclyVerifiableDkg { } } + // TODO: Make private, use `share` instead. Currently used only in bindings pub fn create_share( &self, rng: &mut R, @@ -248,6 +291,10 @@ impl PubliclyVerifiableDkg { return Err(Error::UnknownDealer(sender.clone().address)); } + // TODO: Throw error instead of silently accepting excess shares? + // if self.vss.len() < self.dkg_params.shares_num as usize { + // self.vss.insert(sender.address.clone(), pvss.clone()); + // } self.vss.insert(sender.address.clone(), pvss.clone()); // we keep track of the amount of shares seen until the security @@ -751,3 +798,21 @@ mod test_aggregation { assert!(dkg.verify_message(&sender, &aggregate).is_err()); } } + +/// Test DKG parameters +#[cfg(test)] +mod test_dkg_params { + const TAU: u32 = 0; + + #[test] + fn test_shares_num_less_than_security_threshold() { + let dkg_params = super::DkgParams::new(TAU, 4, 3); + assert!(dkg_params.is_err()); + } + + #[test] + fn test_valid_dkg_params() { + let dkg_params = super::DkgParams::new(TAU, 2, 3); + assert!(dkg_params.is_ok()); + } +} diff --git a/ferveo/src/lib.rs b/ferveo/src/lib.rs index 59a44024..4a29f099 100644 --- a/ferveo/src/lib.rs +++ b/ferveo/src/lib.rs @@ -101,6 +101,12 @@ pub enum Error { #[error("Invalid variant: {0}")] InvalidVariant(String), + + #[error("Invalid DKG parameters: number of shares {0}, threshold {1}")] + InvalidDkgParameters(u32, u32), + + #[error("Invalid share index: {0}")] + InvalidShareIndex(u32), } pub type Result = std::result::Result; @@ -410,7 +416,7 @@ mod test_dkg_full { &domain_points, &dkg.pvss_params.h.into_affine(), &x_r, - dkg.dkg_params.security_threshold as usize, + dkg.dkg_params.security_threshold() as usize, rng, ); (v_addr.clone(), deltas_i) @@ -439,11 +445,13 @@ mod test_dkg_full { // Creates updated private key shares // TODO: Why not using dkg.aggregate()? let pvss_aggregated = aggregate(&dkg.vss); - pvss_aggregated.update_private_key_share_for_recovery( - &decryption_key, - validator.share_index, - updates_for_participant.as_slice(), - ) + pvss_aggregated + .update_private_key_share_for_recovery( + &decryption_key, + validator.share_index, + updates_for_participant.as_slice(), + ) + .unwrap() }) .collect(); @@ -552,7 +560,7 @@ mod test_dkg_full { let deltas_i = prepare_share_updates_for_refresh::( &domain_points, &dkg.pvss_params.h.into_affine(), - dkg.dkg_params.security_threshold as usize, + dkg.dkg_params.security_threshold() as usize, rng, ); (v_addr.clone(), deltas_i) @@ -582,11 +590,13 @@ mod test_dkg_full { // Creates updated private key shares // TODO: Why not using dkg.aggregate()? let pvss_aggregated = aggregate(&dkg.vss); - pvss_aggregated.update_private_key_share_for_recovery( - &decryption_key, - validator.share_index, - updates_for_participant.as_slice(), - ) + pvss_aggregated + .update_private_key_share_for_recovery( + &decryption_key, + validator.share_index, + updates_for_participant.as_slice(), + ) + .unwrap() }) .collect(); diff --git a/ferveo/src/pvss.rs b/ferveo/src/pvss.rs index 4f63da82..495018bf 100644 --- a/ferveo/src/pvss.rs +++ b/ferveo/src/pvss.rs @@ -40,9 +40,6 @@ pub trait Aggregate {} /// Apply trait gate to Aggregated marker struct impl Aggregate for Aggregated {} -// /// Type alias for non aggregated PVSS transcripts -// pub type Pvss = PubliclyVerifiableSS; - /// Type alias for aggregated PVSS transcripts pub type AggregatedPvss = PubliclyVerifiableSS; @@ -138,7 +135,7 @@ impl PubliclyVerifiableSS { ) -> Result { let phi = SecretPolynomial::::new( s, - (dkg.dkg_params.security_threshold - 1) as usize, + (dkg.dkg_params.security_threshold() - 1) as usize, rng, ); @@ -311,19 +308,19 @@ impl PubliclyVerifiableSS { &self, validator_decryption_key: &E::ScalarField, share_index: usize, - ) -> PrivateKeyShare { + ) -> Result> { // Decrypt private key shares https://nikkolasg.github.io/ferveo/pvss.html#validator-decryption-of-private-key-shares let private_key_share = self .shares .get(share_index) - .unwrap() + .ok_or(Error::InvalidShareIndex(share_index as u32))? .mul( validator_decryption_key .inverse() .expect("Validator decryption key must have an inverse"), ) .into_affine(); - PrivateKeyShare { private_key_share } + Ok(PrivateKeyShare { private_key_share }) } pub fn make_decryption_share_simple( @@ -335,7 +332,7 @@ impl PubliclyVerifiableSS { g_inv: &E::G1Prepared, ) -> Result> { let private_key_share = self - .decrypt_private_key_share(validator_decryption_key, share_index); + .decrypt_private_key_share(validator_decryption_key, share_index)?; DecryptionShareSimple::create( validator_decryption_key, &private_key_share, @@ -356,7 +353,7 @@ impl PubliclyVerifiableSS { g_inv: &E::G1Prepared, ) -> Result> { let private_key_share = self - .decrypt_private_key_share(validator_decryption_key, share_index); + .decrypt_private_key_share(validator_decryption_key, share_index)?; // We use the `prepare_combine_simple` function to precompute the lagrange coefficients let lagrange_coeffs = prepare_combine_simple::(domain_points); @@ -379,13 +376,16 @@ impl PubliclyVerifiableSS { validator_decryption_key: &E::ScalarField, share_index: usize, share_updates: &[E::G2], - ) -> PrivateKeyShare { + ) -> Result> { // Retrieves their private key share let private_key_share = self - .decrypt_private_key_share(validator_decryption_key, share_index); + .decrypt_private_key_share(validator_decryption_key, share_index)?; // And updates their share - apply_updates_to_private_share::(&private_key_share, share_updates) + Ok(apply_updates_to_private_share::( + &private_key_share, + share_updates, + )) } } @@ -482,7 +482,7 @@ mod test_pvss { // Check that a polynomial of the correct degree was created assert_eq!( pvss.coeffs.len(), - dkg.dkg_params.security_threshold as usize + dkg.dkg_params.security_threshold() as usize ); // Check that the correct number of shares were created assert_eq!(pvss.shares.len(), dkg.validators.len()); @@ -555,11 +555,7 @@ mod test_pvss { // And because of that the DKG should fail let result = PubliclyVerifiableDkg::new( &validators, - &DkgParams { - tau: 0, - security_threshold, - shares_num, - }, + &DkgParams::new(0, security_threshold, shares_num).unwrap(), &me, ); assert!(result.is_err()); @@ -578,7 +574,7 @@ mod test_pvss { // Check that a polynomial of the correct degree was created assert_eq!( aggregate.coeffs.len(), - dkg.dkg_params.security_threshold as usize + dkg.dkg_params.security_threshold() as usize ); // Check that the correct number of shares were created assert_eq!(aggregate.shares.len(), dkg.validators.len()); diff --git a/ferveo/src/validator.rs b/ferveo/src/validator.rs index 7b014266..d931ca06 100644 --- a/ferveo/src/validator.rs +++ b/ferveo/src/validator.rs @@ -54,6 +54,7 @@ impl PartialOrd for Validator { } impl Ord for Validator { + // Validators are ordered by their address only fn cmp(&self, other: &Self) -> Ordering { self.address.cmp(&other.address) }