From e6a7f6e55a34d892e664160f1f8cffe6e88c79da Mon Sep 17 00:00:00 2001 From: Piotr Roslaniec Date: Mon, 29 Jan 2024 17:18:53 +0100 Subject: [PATCH] feature(dkg): prevent panics during transcript aggregation --- ferveo/src/api.rs | 56 +++++++++++++++++++++++++++++++---- ferveo/src/bindings_python.rs | 10 +++++-- ferveo/src/bindings_wasm.rs | 3 +- ferveo/src/dkg.rs | 2 +- ferveo/src/lib.rs | 14 +++++---- ferveo/src/pvss.rs | 30 +++++++++---------- 6 files changed, 85 insertions(+), 30 deletions(-) diff --git a/ferveo/src/api.rs b/ferveo/src/api.rs index aab043d6..c8a9f6c3 100644 --- a/ferveo/src/api.rs +++ b/ferveo/src/api.rs @@ -246,7 +246,7 @@ impl Dkg { for (validator, transcript) in messages { self.0.deal(validator, transcript)?; } - Ok(AggregatedTranscript(crate::pvss::aggregate(&self.0.vss))) + Ok(AggregatedTranscript(crate::pvss::aggregate(&self.0.vss)?)) } pub fn public_params(&self) -> DkgPublicParameters { @@ -268,9 +268,9 @@ fn make_pvss_map(messages: &[ValidatorMessage]) -> PVSSMap { pub struct AggregatedTranscript(PubliclyVerifiableSS); impl AggregatedTranscript { - pub fn new(messages: &[ValidatorMessage]) -> Self { + pub fn new(messages: &[ValidatorMessage]) -> Result { let pvss_map = make_pvss_map(messages); - AggregatedTranscript(crate::pvss::aggregate(&pvss_map)) + Ok(AggregatedTranscript(crate::pvss::aggregate(&pvss_map)?)) } pub fn verify( @@ -625,6 +625,10 @@ mod test_ferveo_api { assert!(result.is_err()); } + // Note that the server and client code are using the same underlying + // implementation for aggregation and aggregate verification. + // Here, we focus on testing user-facing APIs for server and client users. + #[test] fn server_side_local_verification() { let rng = &mut StdRng::seed_from_u64(0); @@ -643,6 +647,41 @@ mod test_ferveo_api { assert!(local_aggregate .verify(dkg.0.dkg_params.shares_num(), &messages) .is_ok()); + + // Test negative cases + + // Notice that the dkg instance is mutable, so we need to get a fresh one + // for every test case + + // Should fail if no transcripts are provided + let mut dkg = + Dkg::new(TAU, SHARES_NUM, SECURITY_THRESHOLD, &validators, &me) + .unwrap(); + let result = dkg.aggregate_transcripts(&[]); + assert!(result.is_err()); + + // Not enough transcripts + let mut dkg = + Dkg::new(TAU, SHARES_NUM, SECURITY_THRESHOLD, &validators, &me) + .unwrap(); + let not_enough_messages = &messages[..SECURITY_THRESHOLD as usize - 1]; + assert!(not_enough_messages.len() < SECURITY_THRESHOLD as usize); + let insufficient_aggregate = + dkg.aggregate_transcripts(not_enough_messages).unwrap(); + let result = insufficient_aggregate.verify(SHARES_NUM, &messages); + assert!(result.is_err()); + + // Unexpected transcripts in the aggregate or transcripts from a different ritual + // Using same DKG parameters, but different DKG instances and validators + let mut dkg = + Dkg::new(TAU, SHARES_NUM, SECURITY_THRESHOLD, &validators, &me) + .unwrap(); + let (bad_messages, _, _) = + make_test_inputs(rng, TAU, SECURITY_THRESHOLD, SHARES_NUM); + let mixed_messages = [&messages[..2], &bad_messages[..1]].concat(); + let bad_aggregate = dkg.aggregate_transcripts(&mixed_messages).unwrap(); + let result = bad_aggregate.verify(SHARES_NUM, &messages); + assert!(result.is_err()); } #[test] @@ -656,7 +695,8 @@ mod test_ferveo_api { let messages = &messages[..SECURITY_THRESHOLD as usize]; // Create an aggregated transcript on the client side - let aggregated_transcript = AggregatedTranscript::new(messages); + let aggregated_transcript = + AggregatedTranscript::new(messages).unwrap(); // We are separating the verification from the aggregation since the client may fetch // the aggregate from a side-channel or decide to persist it and verify it later @@ -668,11 +708,15 @@ mod test_ferveo_api { // Test negative cases + // Should fail if no transcripts are provided + let result = AggregatedTranscript::new(&[]); + assert!(result.is_err()); + // Not enough transcripts let not_enough_messages = &messages[..SECURITY_THRESHOLD as usize - 1]; assert!(not_enough_messages.len() < SECURITY_THRESHOLD as usize); let insufficient_aggregate = - AggregatedTranscript::new(not_enough_messages); + AggregatedTranscript::new(not_enough_messages).unwrap(); let result = insufficient_aggregate.verify(SHARES_NUM, messages); assert!(result.is_err()); @@ -681,7 +725,7 @@ mod test_ferveo_api { let (bad_messages, _, _) = make_test_inputs(rng, TAU, SECURITY_THRESHOLD, SHARES_NUM); let mixed_messages = [&messages[..2], &bad_messages[..1]].concat(); - let bad_aggregate = AggregatedTranscript::new(&mixed_messages); + let bad_aggregate = AggregatedTranscript::new(&mixed_messages).unwrap(); let result = bad_aggregate.verify(SHARES_NUM, messages); assert!(result.is_err()); } diff --git a/ferveo/src/bindings_python.rs b/ferveo/src/bindings_python.rs index c35dc291..fab67561 100644 --- a/ferveo/src/bindings_python.rs +++ b/ferveo/src/bindings_python.rs @@ -113,6 +113,9 @@ impl From for PyErr { "{index}" )) }, + Error::NoTranscriptsToAggregate => { + NoTranscriptsToAggregate::new_err("") + }, }, _ => default(), } @@ -149,6 +152,7 @@ create_exception!(exceptions, InvalidByteLength, PyValueError); create_exception!(exceptions, InvalidVariant, PyValueError); create_exception!(exceptions, InvalidDkgParameters, PyValueError); create_exception!(exceptions, InvalidShareIndex, PyValueError); +create_exception!(exceptions, NoTranscriptsToAggregate, PyValueError); fn from_py_bytes(bytes: &[u8]) -> PyResult { T::from_bytes(bytes) @@ -580,10 +584,12 @@ generate_bytes_serialization!(AggregatedTranscript); #[pymethods] impl AggregatedTranscript { #[new] - pub fn new(messages: Vec) -> Self { + pub fn new(messages: Vec) -> PyResult { let messages: Vec<_> = messages.into_iter().map(|vm| vm.to_inner()).collect(); - Self(api::AggregatedTranscript::new(&messages)) + let inner = api::AggregatedTranscript::new(&messages) + .map_err(FerveoPythonError::FerveoError)?; + Ok(Self(inner)) } pub fn verify( diff --git a/ferveo/src/bindings_wasm.rs b/ferveo/src/bindings_wasm.rs index 5a23909a..1396de13 100644 --- a/ferveo/src/bindings_wasm.rs +++ b/ferveo/src/bindings_wasm.rs @@ -507,7 +507,8 @@ impl AggregatedTranscript { ) -> JsResult { set_panic_hook(); let messages = unwrap_messages_js(messages)?; - let aggregated_transcript = api::AggregatedTranscript::new(&messages); + let aggregated_transcript = + api::AggregatedTranscript::new(&messages).unwrap(); Ok(AggregatedTranscript(aggregated_transcript)) } diff --git a/ferveo/src/dkg.rs b/ferveo/src/dkg.rs index 10f73d7e..8259e5c3 100644 --- a/ferveo/src/dkg.rs +++ b/ferveo/src/dkg.rs @@ -179,7 +179,7 @@ impl PubliclyVerifiableDkg { DkgState::Dealt => { let public_key = self.public_key(); Ok(Message::Aggregate(Aggregation { - vss: aggregate(&self.vss), + vss: aggregate(&self.vss)?, public_key, })) } diff --git a/ferveo/src/lib.rs b/ferveo/src/lib.rs index 0fca990c..f2eb357a 100644 --- a/ferveo/src/lib.rs +++ b/ferveo/src/lib.rs @@ -113,6 +113,10 @@ pub enum Error { /// DKG may not contain duplicated share indices #[error("Duplicated share index: {0}")] DuplicatedShareIndex(u32), + + /// Creating a transcript aggregate requires at least one transcript + #[error("No transcripts to aggregate")] + NoTranscriptsToAggregate, } pub type Result = std::result::Result; @@ -148,7 +152,7 @@ mod test_dkg_full { Vec>, SharedSecret, ) { - let pvss_aggregated = aggregate(&dkg.vss); + let pvss_aggregated = aggregate(&dkg.vss).unwrap(); assert!(pvss_aggregated.verify_aggregation(dkg).is_ok()); let decryption_shares: Vec> = @@ -243,7 +247,7 @@ mod test_dkg_full { ) .unwrap(); - let pvss_aggregated = aggregate(&dkg.vss); + let pvss_aggregated = aggregate(&dkg.vss).unwrap(); pvss_aggregated.verify_aggregation(&dkg).unwrap(); let domain_points = dkg .domain @@ -430,7 +434,7 @@ mod test_dkg_full { // Creates updated private key shares // TODO: Why not using dkg.aggregate()? - let pvss_aggregated = aggregate(&dkg.vss); + let pvss_aggregated = aggregate(&dkg.vss).unwrap(); pvss_aggregated .update_private_key_share_for_recovery( &decryption_key, @@ -461,7 +465,7 @@ mod test_dkg_full { .enumerate() .map(|(share_index, validator_keypair)| { // TODO: Why not using dkg.aggregate()? - let pvss_aggregated = aggregate(&dkg.vss); + let pvss_aggregated = aggregate(&dkg.vss).unwrap(); pvss_aggregated .make_decryption_share_simple( &ciphertext.header().unwrap(), @@ -573,7 +577,7 @@ mod test_dkg_full { // Creates updated private key shares // TODO: Why not using dkg.aggregate()? - let pvss_aggregated = aggregate(&dkg.vss); + let pvss_aggregated = aggregate(&dkg.vss).unwrap(); pvss_aggregated .update_private_key_share_for_recovery( &decryption_key, diff --git a/ferveo/src/pvss.rs b/ferveo/src/pvss.rs index a5e458d7..60910158 100644 --- a/ferveo/src/pvss.rs +++ b/ferveo/src/pvss.rs @@ -381,13 +381,13 @@ impl PubliclyVerifiableSS { /// Aggregate the PVSS instances in `pvss` from DKG session `dkg` /// into a new PVSS instance /// See: https://nikkolasg.github.io/ferveo/pvss.html?highlight=aggregate#aggregation -pub fn aggregate( +pub(crate) fn aggregate( pvss_map: &PVSSMap, -) -> PubliclyVerifiableSS { - let mut pvss_iter = pvss_map.iter(); - let (_, first_pvss) = pvss_iter +) -> Result> { + let mut pvss_iter = pvss_map.values(); + let first_pvss = pvss_iter .next() - .expect("May not aggregate empty PVSS instances"); + .ok_or_else(|| Error::NoTranscriptsToAggregate)?; let mut coeffs = batch_to_projective_g1::(&first_pvss.coeffs); let mut sigma = first_pvss.sigma; @@ -396,25 +396,25 @@ pub fn aggregate( // So now we're iterating over the PVSS instances, and adding their coefficients and shares, and their sigma // sigma is the sum of all the sigma_i, which is the proof of knowledge of the secret polynomial // Aggregating is just adding the corresponding values in pvss instances, so pvss = pvss + pvss_j - for (_, next) in pvss_iter { - sigma = (sigma + next.sigma).into(); + for next_pvss in pvss_iter { + sigma = (sigma + next_pvss.sigma).into(); coeffs .iter_mut() - .zip_eq(next.coeffs.iter()) + .zip_eq(next_pvss.coeffs.iter()) .for_each(|(a, b)| *a += b); shares .iter_mut() - .zip_eq(next.shares.iter()) + .zip_eq(next_pvss.shares.iter()) .for_each(|(a, b)| *a += b); } let shares = E::G2::normalize_batch(&shares); - PubliclyVerifiableSS { + Ok(PubliclyVerifiableSS { coeffs: E::G1::normalize_batch(&coeffs), shares, sigma, phantom: Default::default(), - } + }) } #[cfg(test)] @@ -526,7 +526,7 @@ mod test_pvss { #[test] fn test_aggregate_pvss() { let (dkg, _) = setup_dealt_dkg(); - let aggregate = aggregate(&dkg.vss); + let aggregate = aggregate(&dkg.vss).unwrap(); // Check that a polynomial of the correct degree was created assert_eq!( aggregate.coeffs.len(), @@ -542,15 +542,15 @@ mod test_pvss { assert!(aggregate.verify_aggregation(&dkg).expect("Test failed"),); } - /// Check that if the aggregated pvss transcript has an + /// Check that if the aggregated PVSS transcript has an /// incorrect constant term, the verification fails #[test] fn test_verify_aggregation_fails_if_constant_term_wrong() { let (dkg, _) = setup_dealt_dkg(); - let mut aggregated = aggregate(&dkg.vss); + let mut aggregated = aggregate(&dkg.vss).unwrap(); while aggregated.coeffs[0] == G1::zero() { let (dkg, _) = setup_dkg(0); - aggregated = aggregate(&dkg.vss); + aggregated = aggregate(&dkg.vss).unwrap(); } aggregated.coeffs[0] = G1::zero(); assert_eq!(