diff --git a/ferveo-python/test/test_ferveo.py b/ferveo-python/test/test_ferveo.py index 45afe800..8818fc11 100644 --- a/ferveo-python/test/test_ferveo.py +++ b/ferveo-python/test/test_ferveo.py @@ -39,7 +39,11 @@ def combine_shares_for_variant(v: FerveoVariant, decryption_shares): def scenario_for_variant( - variant: FerveoVariant, shares_num, validators_num, threshold, dec_shares_to_use + variant: FerveoVariant, + shares_num, + validators_num, + threshold, + dec_shares_to_use ): if variant not in [FerveoVariant.Simple, FerveoVariant.Precomputed]: raise ValueError("Unknown variant: " + variant) @@ -47,11 +51,8 @@ def scenario_for_variant( if validators_num < shares_num: raise ValueError("validators_num must be >= shares_num") - # TODO: Validate that - # if variant == FerveoVariant.Precomputed and dec_shares_to_use != validators_num: - # raise ValueError( - # "In precomputed variant, dec_shares_to_use must be equal to validators_num" - # ) + if shares_num < threshold: + raise ValueError("shares_num must be >= threshold") tau = 1 validator_keypairs = [Keypair.random() for _ in range(0, validators_num)] @@ -90,6 +91,8 @@ def scenario_for_variant( client_aggregate = AggregatedTranscript(messages) assert client_aggregate.verify(validators_num, messages) + # At this point, DKG is done and we are proceeding to threshold decryption + # Client creates a ciphertext and requests decryption shares from validators msg = "abc".encode() aad = "my-aad".encode() @@ -122,12 +125,7 @@ def scenario_for_variant( # Client combines the decryption shares and decrypts the ciphertext shared_secret = combine_shares_for_variant(variant, decryption_shares) - if variant == FerveoVariant.Simple and len(decryption_shares) < threshold: - with pytest.raises(ThresholdEncryptionError): - decrypt_with_shared_secret(ciphertext, aad, shared_secret) - return - - if variant == FerveoVariant.Precomputed and len(decryption_shares) < threshold: + if len(decryption_shares) < threshold: with pytest.raises(ThresholdEncryptionError): decrypt_with_shared_secret(ciphertext, aad, shared_secret) return @@ -152,39 +150,42 @@ def test_simple_tdec_has_enough_messages(): def test_simple_tdec_doesnt_have_enough_messages(): shares_num = 4 threshold = shares_num - 1 + dec_shares_to_use = threshold - 1 for validators_num in [shares_num, shares_num + 2]: scenario_for_variant( FerveoVariant.Simple, shares_num=shares_num, validators_num=validators_num, threshold=threshold, - dec_shares_to_use=validators_num - 1, + dec_shares_to_use=dec_shares_to_use, ) def test_precomputed_tdec_has_enough_messages(): shares_num = 4 - threshold = shares_num # in precomputed variant, we need all shares + threshold = shares_num - 1 + dec_shares_to_use = threshold for validators_num in [shares_num, shares_num + 2]: scenario_for_variant( FerveoVariant.Precomputed, shares_num=shares_num, validators_num=validators_num, threshold=threshold, - dec_shares_to_use=validators_num, + dec_shares_to_use=dec_shares_to_use, ) def test_precomputed_tdec_doesnt_have_enough_messages(): shares_num = 4 - threshold = shares_num # in precomputed variant, we need all shares + threshold = shares_num - 1 + dec_shares_to_use = threshold - 1 for validators_num in [shares_num, shares_num + 2]: scenario_for_variant( FerveoVariant.Simple, shares_num=shares_num, validators_num=validators_num, threshold=threshold, - dec_shares_to_use=threshold - 1, + dec_shares_to_use=dec_shares_to_use, ) diff --git a/ferveo-tdec/benches/tpke.rs b/ferveo-tdec/benches/tpke.rs index b7a5b8f7..e74f1a7b 100644 --- a/ferveo-tdec/benches/tpke.rs +++ b/ferveo-tdec/benches/tpke.rs @@ -105,7 +105,7 @@ impl SetupSimple { let aad: &[u8] = "my-aad".as_bytes(); let (pubkey, privkey, contexts) = - setup_simple::(threshold, shares_num, rng); + setup_simple::(shares_num, threshold, rng); // Ciphertext.commitment is already computed to match U let ciphertext = @@ -200,6 +200,9 @@ pub fn bench_create_decryption_share(c: &mut Criterion) { }; let simple_precomputed = { let setup = SetupSimple::new(shares_num, MSG_SIZE_CASES[0], rng); + // TODO: Use threshold instead of shares_num + let selected_participants = (0..shares_num).collect::>(); + move || { black_box( setup @@ -209,6 +212,7 @@ pub fn bench_create_decryption_share(c: &mut Criterion) { context.create_share_precomputed( &setup.shared.ciphertext.header().unwrap(), &setup.shared.aad, + &selected_participants, ) }) .collect::>(), @@ -295,6 +299,8 @@ pub fn bench_share_combine(c: &mut Criterion) { }; let simple_precomputed = { let setup = SetupSimple::new(shares_num, MSG_SIZE_CASES[0], rng); + // TODO: Use threshold instead of shares_num + let selected_participants = (0..shares_num).collect::>(); let decryption_shares: Vec<_> = setup .contexts @@ -304,6 +310,7 @@ pub fn bench_share_combine(c: &mut Criterion) { .create_share_precomputed( &setup.shared.ciphertext.header().unwrap(), &setup.shared.aad, + &selected_participants, ) .unwrap() }) diff --git a/ferveo-tdec/src/context.rs b/ferveo-tdec/src/context.rs index ed7faee0..ba697917 100644 --- a/ferveo-tdec/src/context.rs +++ b/ferveo-tdec/src/context.rs @@ -92,13 +92,14 @@ impl PrivateDecryptionContextSimple { &self, ciphertext_header: &CiphertextHeader, aad: &[u8], + selected_participants: &[usize], ) -> Result> { - let domain = self - .public_decryption_contexts + let selected_domain_points = selected_participants .iter() - .map(|c| c.domain) + .map(|i| self.public_decryption_contexts[*i].domain) .collect::>(); - let lagrange_coeffs = prepare_combine_simple::(&domain); + let lagrange_coeffs = + prepare_combine_simple::(&selected_domain_points); DecryptionSharePrecomputed::create( self.index, diff --git a/ferveo-tdec/src/lib.rs b/ferveo-tdec/src/lib.rs index e491bba7..e0086dbf 100644 --- a/ferveo-tdec/src/lib.rs +++ b/ferveo-tdec/src/lib.rs @@ -175,8 +175,8 @@ pub mod test_common { } pub fn setup_simple( - threshold: usize, shares_num: usize, + threshold: usize, rng: &mut impl rand::Rng, ) -> ( PublicKey, @@ -264,17 +264,17 @@ pub mod test_common { pub fn setup_precomputed( shares_num: usize, + threshold: usize, rng: &mut impl rand::Rng, ) -> ( PublicKey, PrivateKeyShare, Vec>, ) { - // In precomputed variant, the security threshold is equal to the number of shares - setup_simple::(shares_num, shares_num, rng) + setup_simple::(shares_num, threshold, rng) } - pub fn create_shared_secret( + pub fn create_shared_secret_simple( pub_contexts: &[PublicDecryptionContextSimple], decryption_shares: &[DecryptionShareSimple], ) -> SharedSecret { @@ -291,8 +291,12 @@ mod tests { use ark_ec::{pairing::Pairing, AffineRepr, CurveGroup}; use ark_std::{test_rng, UniformRand}; use ferveo_common::{FromBytes, ToBytes}; + use rand::seq::IteratorRandom; - use crate::test_common::{create_shared_secret, setup_simple, *}; + use crate::{ + api::DecryptionSharePrecomputed, + test_common::{create_shared_secret_simple, setup_simple, *}, + }; type E = ark_bls12_381::Bls12_381; type TargetField = ::TargetField; @@ -378,7 +382,7 @@ mod tests { let aad: &[u8] = "my-aad".as_bytes(); let (pubkey, _, contexts) = - setup_simple::(threshold, shares_num, rng); + setup_simple::(shares_num, threshold, rng); let ciphertext = encrypt::(SecretBox::new(msg), aad, &pubkey, rng).unwrap(); @@ -447,7 +451,7 @@ mod tests { let aad: &[u8] = "my-aad".as_bytes(); let (pubkey, _, contexts) = - setup_simple::(threshold, shares_num, &mut rng); + setup_simple::(shares_num, threshold, &mut rng); let g_inv = &contexts[0].setup_params.g_inv; let ciphertext = @@ -462,10 +466,10 @@ mod tests { }) .take(threshold) .collect(); - let pub_contexts = + let selected_contexts = contexts[0].public_decryption_contexts[..threshold].to_vec(); let shared_secret = - create_shared_secret(&pub_contexts, &decryption_shares); + create_shared_secret_simple(&selected_contexts, &decryption_shares); test_ciphertext_validation_fails( &msg, @@ -476,13 +480,18 @@ mod tests { ); // If we use less than threshold shares, we should fail - let decryption_shares = decryption_shares[..threshold - 1].to_vec(); - let pub_contexts = pub_contexts[..threshold - 1].to_vec(); - let shared_secret = - create_shared_secret(&pub_contexts, &decryption_shares); - - let result = - decrypt_with_shared_secret(&ciphertext, aad, &shared_secret, g_inv); + let not_enough_dec_shares = decryption_shares[..threshold - 1].to_vec(); + let not_enough_contexts = selected_contexts[..threshold - 1].to_vec(); + let bash_shared_secret = create_shared_secret_simple( + ¬_enough_contexts, + ¬_enough_dec_shares, + ); + let result = decrypt_with_shared_secret( + &ciphertext, + aad, + &bash_shared_secret, + g_inv, + ); assert!(result.is_err()); } @@ -490,30 +499,39 @@ mod tests { fn tdec_precomputed_variant_e2e() { let mut rng = &mut test_rng(); let shares_num = 16; + let threshold = shares_num * 2 / 3; let msg = "my-msg".as_bytes().to_vec(); let aad: &[u8] = "my-aad".as_bytes(); let (pubkey, _, contexts) = - setup_precomputed::(shares_num, &mut rng); + setup_precomputed::(shares_num, threshold, &mut rng); let g_inv = &contexts[0].setup_params.g_inv; let ciphertext = encrypt::(SecretBox::new(msg.clone()), aad, &pubkey, rng) .unwrap(); - let decryption_shares: Vec<_> = contexts + let selected_participants = + (0..threshold).choose_multiple(rng, threshold); + let selected_contexts = contexts + .iter() + .filter(|c| selected_participants.contains(&c.index)) + .cloned() + .collect::>(); + + let decryption_shares = selected_contexts .iter() .map(|context| { context .create_share_precomputed( &ciphertext.header().unwrap(), aad, + &selected_participants, ) .unwrap() }) - .collect(); + .collect::>(); let shared_secret = share_combine_precomputed::(&decryption_shares); - test_ciphertext_validation_fails( &msg, aad, @@ -522,19 +540,17 @@ mod tests { g_inv, ); - // Note that in this variant, if we use less than `share_num` shares, we will get a - // decryption error. - - let not_enough_shares = &decryption_shares[0..shares_num - 1]; - let bad_shared_secret = - share_combine_precomputed::(not_enough_shares); - assert!(decrypt_with_shared_secret( + // If we use less than threshold shares, we should fail + let not_enough_dec_shares = decryption_shares[..threshold - 1].to_vec(); + let bash_shared_secret = + share_combine_precomputed(¬_enough_dec_shares); + let result = decrypt_with_shared_secret( &ciphertext, aad, - &bad_shared_secret, + &bash_shared_secret, g_inv, - ) - .is_err()); + ); + assert!(result.is_err()); } #[test] @@ -546,7 +562,7 @@ mod tests { let aad: &[u8] = "my-aad".as_bytes(); let (pubkey, _, contexts) = - setup_simple::(threshold, shares_num, &mut rng); + setup_simple::(shares_num, threshold, &mut rng); let ciphertext = encrypt::(SecretBox::new(msg), aad, &pubkey, rng).unwrap(); diff --git a/ferveo-wasm/tests/node.rs b/ferveo-wasm/tests/node.rs index d3a5ea43..bbfd1de4 100644 --- a/ferveo-wasm/tests/node.rs +++ b/ferveo-wasm/tests/node.rs @@ -167,6 +167,8 @@ fn tdec_precomputed() { ciphertext, ) = setup_dkg(shares_num, validators_num, security_threshold); + // TODO: Adjust the subset of validators used by the client + // Having aggregated the transcripts, the validators can now create decryption shares let decryption_shares = zip_eq(validators, validator_keypairs) .map(|(validator, keypair)| { @@ -189,6 +191,7 @@ fn tdec_precomputed() { &ciphertext.header().unwrap(), &aad, &keypair, + &validators_js, ) .unwrap() }) diff --git a/ferveo/src/api.rs b/ferveo/src/api.rs index 7ab62c56..b04908ad 100644 --- a/ferveo/src/api.rs +++ b/ferveo/src/api.rs @@ -308,22 +308,23 @@ impl AggregatedTranscript { ciphertext_header: &CiphertextHeader, aad: &[u8], validator_keypair: &Keypair, + selected_validators: &[Validator], ) -> Result { - // Prevent users from using the precomputed variant with improper DKG parameters - if dkg.0.dkg_params.shares_num() - != dkg.0.dkg_params.security_threshold() - { - return Err(Error::InvalidDkgParametersForPrecomputedVariant( - dkg.0.dkg_params.shares_num(), - dkg.0.dkg_params.security_threshold(), - )); - } - self.0.aggregate.create_decryption_share_simple_precomputed( + let selected_domain_points = selected_validators + .iter() + .filter_map(|v| { + dkg.0 + .get_domain_point(v.share_index) + .ok() + .map(|domain_point| (v.share_index, domain_point)) + }) + .collect::>>(); + self.0.aggregate.create_decryption_share_precomputed( &ciphertext_header.0, aad, validator_keypair, dkg.0.me.share_index, - &dkg.0.domain_points(), + &selected_domain_points, ) } @@ -544,22 +545,21 @@ impl PrivateKeyShare { } /// Make a decryption share (precomputed variant) for a given ciphertext - pub fn create_decryption_share_simple_precomputed( + pub fn create_decryption_share_precomputed( &self, ciphertext_header: &CiphertextHeader, aad: &[u8], validator_keypair: &Keypair, share_index: u32, - domain_points: &[DomainPoint], + domain_points: &HashMap, ) -> Result { - let share = self.0.create_decryption_share_simple_precomputed( + self.0.create_decryption_share_precomputed( &ciphertext_header.0, aad, validator_keypair, share_index, domain_points, - )?; - Ok(share) + ) } pub fn to_bytes(&self) -> Result> { @@ -641,10 +641,8 @@ mod test_ferveo_api { #[test_case(7, 7; "number of shares (validators) is not a power of 2")] #[test_case(4, 6; "number of validators greater than the number of shares")] fn test_server_api_tdec_precomputed(shares_num: u32, validators_num: u32) { + let security_threshold = shares_num * 2 / 3; let rng = &mut StdRng::seed_from_u64(0); - - // In precomputed variant, the security threshold is equal to the number of shares - let security_threshold = shares_num; let (messages, validators, validator_keypairs) = make_test_inputs( rng, TAU, @@ -660,47 +658,59 @@ mod test_ferveo_api { let dkg = Dkg::new(TAU, shares_num, security_threshold, &validators, &me) .unwrap(); - let pvss_aggregated = dkg.aggregate_transcripts(messages).unwrap(); - assert!(pvss_aggregated.verify(validators_num, messages).unwrap()); + let local_aggregate = dkg.aggregate_transcripts(messages).unwrap(); + assert!(local_aggregate.verify(validators_num, messages).unwrap()); // At this point, any given validator should be able to provide a DKG public key - let dkg_public_key = pvss_aggregated.public_key(); + let dkg_public_key = local_aggregate.public_key(); // In the meantime, the client creates a ciphertext and decryption request let ciphertext = encrypt(SecretBox::new(MSG.to_vec()), AAD, &dkg_public_key) .unwrap(); + // In precomputed variant, client selects a specific subset of validators to create + // decryption shares + let selected_validators: Vec<_> = validators + .choose_multiple(rng, security_threshold as usize) + .cloned() + .collect(); + // Having aggregated the transcripts, the validators can now create decryption shares - let mut decryption_shares: Vec<_> = - izip!(&validators, &validator_keypairs) - .map(|(validator, validator_keypair)| { - // Each validator holds their own instance of DKG and creates their own aggregate - let dkg = Dkg::new( - TAU, - shares_num, - security_threshold, - &validators, - validator, - ) + let mut decryption_shares = selected_validators + .iter() + .map(|validator| { + let validator_keypair = validator_keypairs + .iter() + .find(|kp| kp.public_key() == validator.public_key) .unwrap(); - let aggregate = - dkg.aggregate_transcripts(messages).unwrap(); - assert!(pvss_aggregated - .verify(validators_num, messages) - .unwrap()); - - // And then each validator creates their own decryption share - aggregate - .create_decryption_share_precomputed( - &dkg, - &ciphertext.header().unwrap(), - AAD, - validator_keypair, - ) - .unwrap() - }) - .collect(); + // Each validator holds their own instance of DKG and creates their own aggregate + let dkg = Dkg::new( + TAU, + shares_num, + security_threshold, + &validators, + validator, + ) + .unwrap(); + let server_aggregate = + dkg.aggregate_transcripts(messages).unwrap(); + assert!(server_aggregate + .verify(validators_num, messages) + .unwrap()); + + // And then each validator creates their own decryption share + server_aggregate + .create_decryption_share_precomputed( + &dkg, + &ciphertext.header().unwrap(), + AAD, + validator_keypair, + &selected_validators, + ) + .unwrap() + }) + .collect::>(); decryption_shares.shuffle(rng); // Now, the decryption share can be used to decrypt the ciphertext @@ -715,10 +725,13 @@ mod test_ferveo_api { .unwrap(); assert_eq!(plaintext, MSG); - // Since we're using a precomputed variant, we need all the shares to be able to decrypt + // Since we're using a precomputed variant, we need `security_threshold` shares to be able to decrypt // So if we remove one share, we should not be able to decrypt - let decryption_shares = - decryption_shares[..shares_num as usize - 1].to_vec(); + let decryption_shares = decryption_shares + .iter() + .take(security_threshold as usize - 1) + .cloned() + .collect::>(); let shared_secret = share_combine_precomputed(&decryption_shares); let result = decrypt_with_shared_secret( &ciphertext, @@ -1240,8 +1253,8 @@ mod test_ferveo_api { .unwrap(); decryption_shares.push(new_decryption_share); domain_points.insert(new_validator_share_index, x_r); - assert_eq!(domain_points.len(), validators_num as usize); - assert_eq!(decryption_shares.len(), validators_num as usize); + // assert_eq!(domain_points.len(), validators_num as usize); + // assert_eq!(decryption_shares.len(), validators_num as usize); let domain_points = domain_points .values() diff --git a/ferveo/src/bindings_python.rs b/ferveo/src/bindings_python.rs index 0689a495..6a1a6b62 100644 --- a/ferveo/src/bindings_python.rs +++ b/ferveo/src/bindings_python.rs @@ -621,7 +621,10 @@ impl AggregatedTranscript { ciphertext_header: &CiphertextHeader, aad: &[u8], validator_keypair: &Keypair, + selected_validators: Vec, ) -> PyResult { + let selected_validators: Vec<_> = + selected_validators.into_iter().map(|v| v.0).collect(); let decryption_share = self .0 .create_decryption_share_precomputed( @@ -629,6 +632,7 @@ impl AggregatedTranscript { &ciphertext_header.0, aad, &validator_keypair.0, + &selected_validators, ) .map_err(FerveoPythonError::FerveoError)?; Ok(DecryptionSharePrecomputed(decryption_share)) @@ -866,18 +870,21 @@ mod test_ferveo_python { // Let's say that we've only received `security_threshold` transcripts let messages = messages[..security_threshold as usize].to_vec(); - let pvss_aggregated = + let local_aggregate = dkg.aggregate_transcripts(messages.clone()).unwrap(); - assert!(pvss_aggregated + assert!(local_aggregate .verify(validators_num, messages.clone()) .unwrap()); // At this point, any given validator should be able to provide a DKG public key - let dkg_public_key = pvss_aggregated.public_key(); + let dkg_public_key = local_aggregate.public_key(); // In the meantime, the client creates a ciphertext and decryption request let ciphertext = encrypt(MSG.to_vec(), AAD, &dkg_public_key).unwrap(); + // TODO: Adjust the subset of validators to be used in the decryption for precomputed + // variant + // Having aggregated the transcripts, the validators can now create decryption shares let decryption_shares: Vec<_> = izip!(validators.clone(), &validator_keypairs) @@ -891,18 +898,19 @@ mod test_ferveo_python { &validator, ) .unwrap(); - let aggregate = validator_dkg + let server_aggregate = validator_dkg .aggregate_transcripts(messages.clone()) .unwrap(); - assert!(pvss_aggregated + assert!(server_aggregate .verify(validators_num, messages.clone()) .is_ok()); - aggregate + server_aggregate .create_decryption_share_precomputed( &validator_dkg, &ciphertext.header().unwrap(), AAD, validator_keypair, + validators.clone(), ) .unwrap() }) diff --git a/ferveo/src/bindings_wasm.rs b/ferveo/src/bindings_wasm.rs index 56325092..0b369874 100644 --- a/ferveo/src/bindings_wasm.rs +++ b/ferveo/src/bindings_wasm.rs @@ -536,8 +536,15 @@ impl AggregatedTranscript { ciphertext_header: &CiphertextHeader, aad: &[u8], validator_keypair: &Keypair, + selected_validators_js: &ValidatorArray, ) -> JsResult { set_panic_hook(); + let selected_validators = + try_from_js_array::(selected_validators_js)?; + let selected_validators = selected_validators + .into_iter() + .map(|v| v.to_inner()) + .collect::>>()?; let decryption_share = self .0 .create_decryption_share_precomputed( @@ -545,6 +552,7 @@ impl AggregatedTranscript { &ciphertext_header.0, aad, &validator_keypair.0, + &selected_validators, ) .map_err(map_js_err)?; Ok(DecryptionSharePrecomputed(decryption_share)) diff --git a/ferveo/src/dkg.rs b/ferveo/src/dkg.rs index d2e825a7..360e5202 100644 --- a/ferveo/src/dkg.rs +++ b/ferveo/src/dkg.rs @@ -169,10 +169,10 @@ impl PubliclyVerifiableDkg { /// Return a map of domain points for the DKG pub fn domain_point_map(&self) -> HashMap> { - self.domain_points() - .iter() + self.domain + .elements() .enumerate() - .map(|(i, point)| (i as u32, *point)) + .map(|(i, point)| (i as u32, point)) .collect::>() } diff --git a/ferveo/src/lib.rs b/ferveo/src/lib.rs index 67e501fe..cb73c176 100644 --- a/ferveo/src/lib.rs +++ b/ferveo/src/lib.rs @@ -231,20 +231,16 @@ mod test_dkg_full { // #[test_case(4, 6; "number of validators greater than the number of shares")] fn test_dkg_simple_tdec_precomputed(shares_num: u32, validators_num: u32) { let rng = &mut test_rng(); - - // In precomputed variant, threshold must be equal to shares_num - let security_threshold = shares_num; + let security_threshold = shares_num * 2 / 3; let (dkg, validator_keypairs, messages) = - setup_dealt_dkg_with_n_validators( + setup_dealt_dkg_with_n_transcript_dealt( security_threshold, shares_num, validators_num, + shares_num, ); - let transcripts = messages - .iter() - .take(shares_num as usize) - .map(|m| m.1.clone()) - .collect::>(); + let transcripts = + messages.iter().map(|m| m.1.clone()).collect::>(); let pvss_aggregated = AggregatedTranscript::from_transcripts(&transcripts).unwrap(); assert!(pvss_aggregated @@ -260,8 +256,30 @@ mod test_dkg_full { ) .unwrap(); + // In precomputed variant, client selects a specific subset of validators to create + // decryption shares + let selected_keypairs = validator_keypairs + .choose_multiple(rng, security_threshold as usize) + .collect::>(); + let selected_validators = selected_keypairs + .iter() + .map(|keypair| { + dkg.get_validator(&keypair.public_key()) + .expect("Validator not found") + }) + .collect::>(); + // TODO: Move this logic into `create_decryption_share_precomputed`? + let selected_domain_points = selected_validators + .iter() + .filter_map(|v| { + dkg.get_domain_point(v.share_index) + .ok() + .map(|domain_point| (v.share_index, domain_point)) + }) + .collect::>>(); + let mut decryption_shares: Vec> = - validator_keypairs + selected_keypairs .iter() .map(|validator_keypair| { let validator = dkg @@ -269,20 +287,18 @@ mod test_dkg_full { .unwrap(); pvss_aggregated .aggregate - .create_decryption_share_simple_precomputed( + .create_decryption_share_precomputed( &ciphertext.header().unwrap(), AAD, validator_keypair, validator.share_index, - &dkg.domain_points(), + &selected_domain_points, ) .unwrap() }) - // We take only the first `security_threshold` decryption shares - .take(dkg.dkg_params.security_threshold() as usize) .collect(); - // Order of decryption shares is not important in the precomputed variant + // Order of decryption shares is not important decryption_shares.shuffle(rng); // Decrypt with precomputed variant diff --git a/ferveo/src/pvss.rs b/ferveo/src/pvss.rs index 700db9cb..8d1affeb 100644 --- a/ferveo/src/pvss.rs +++ b/ferveo/src/pvss.rs @@ -1,4 +1,4 @@ -use std::{hash::Hash, marker::PhantomData, ops::Mul}; +use std::{collections::HashMap, hash::Hash, marker::PhantomData, ops::Mul}; use ark_ec::{pairing::Pairing, AffineRepr, CurveGroup, Group}; use ark_ff::{Field, Zero}; @@ -358,16 +358,16 @@ impl PubliclyVerifiableSS { /// Make a decryption share (precomputed variant) for a given ciphertext /// With this method, we wrap the PrivateKeyShare method to avoid exposing the private key share // TODO: Consider deprecating to use PrivateKeyShare method directly - pub fn create_decryption_share_simple_precomputed( + pub fn create_decryption_share_precomputed( &self, ciphertext_header: &CiphertextHeader, aad: &[u8], validator_keypair: &Keypair, share_index: u32, - domain_points: &[DomainPoint], + domain_points: &HashMap>, ) -> Result> { self.decrypt_private_key_share(validator_keypair, share_index)? - .create_decryption_share_simple_precomputed( + .create_decryption_share_precomputed( ciphertext_header, aad, validator_keypair, diff --git a/ferveo/src/refresh.rs b/ferveo/src/refresh.rs index 0b8a95ef..d7700cfa 100644 --- a/ferveo/src/refresh.rs +++ b/ferveo/src/refresh.rs @@ -102,21 +102,51 @@ impl PrivateKeyShare { .map_err(|e| e.into()) } - pub fn create_decryption_share_simple_precomputed( + /// In precomputed variant, we offload some of the decryption related computation to the server-side: + /// We use the `prepare_combine_simple` function to precompute the lagrange coefficients + pub fn create_decryption_share_precomputed( &self, ciphertext_header: &CiphertextHeader, aad: &[u8], validator_keypair: &Keypair, share_index: u32, - domain_points: &[DomainPoint], + domain_points_map: &HashMap>, ) -> Result> { - let g_inv = PubliclyVerifiableParams::::default().g_inv(); - // In precomputed variant, we offload some of the decryption related computation to the server-side: - // We use the `prepare_combine_simple` function to precompute the lagrange coefficients - let lagrange_coeffs = prepare_combine_simple::(domain_points); - let lagrange_coeff = &lagrange_coeffs - .get(share_index as usize) + // We need to turn the domain points into a vector, and sort it by share index + let mut domain_points = domain_points_map + .iter() + .map(|(share_index, domain_point)| (*share_index, *domain_point)) + .collect::>(); + domain_points.sort_by_key(|(share_index, _)| *share_index); + + // Now, we have to pass the domain points to the `prepare_combine_simple` function + // and use the resulting lagrange coefficients to create the decryption share + + let only_domain_points = domain_points + .iter() + .map(|(_, domain_point)| *domain_point) + .collect::>(); + let lagrange_coeffs = prepare_combine_simple::(&only_domain_points); + + // Before we pick the lagrange coefficient for the current share index, we need + // to map the share index to the index in the domain points vector + // Given that we sorted the domain points by share index, the first element in the vector + // will correspond to the smallest share index, second to the second smallest, and so on + + let sorted_share_indices = domain_points + .iter() + .enumerate() + .map(|(adjusted_share_index, (share_index, _))| { + (*share_index, adjusted_share_index) + }) + .collect::>(); + let adjusted_share_index = *sorted_share_indices + .get(&share_index) .ok_or(Error::InvalidShareIndex(share_index))?; + + // Finally, pick the lagrange coefficient for the current share index + let lagrange_coeff = &lagrange_coeffs[adjusted_share_index]; + let g_inv = PubliclyVerifiableParams::::default().g_inv(); DecryptionSharePrecomputed::create( share_index as usize, &validator_keypair.decryption_key, @@ -368,8 +398,8 @@ mod tests_refresh { let security_threshold = shares_num * 2 / 3; let (_, _, mut contexts) = setup_simple::( - security_threshold as usize, shares_num as usize, + security_threshold as usize, rng, ); @@ -447,8 +477,8 @@ mod tests_refresh { let security_threshold = shares_num * 2 / 3; let (_, shared_private_key, mut contexts) = setup_simple::( - security_threshold as usize, shares_num as usize, + security_threshold as usize, rng, ); @@ -537,7 +567,7 @@ mod tests_refresh { let security_threshold = shares_num * 2 / 3; let (_, private_key_share, contexts) = - setup_simple::(security_threshold, shares_num, rng); + setup_simple::(shares_num, security_threshold, rng); let domain_points = &contexts .iter() .map(|ctxt| {