From e6a7f6e55a34d892e664160f1f8cffe6e88c79da Mon Sep 17 00:00:00 2001
From: Piotr Roslaniec
Date: Mon, 29 Jan 2024 17:18:53 +0100
Subject: [PATCH] feature(dkg): prevent panics during transcript aggregation
---
ferveo/src/api.rs | 56 +++++++++++++++++++++++++++++++----
ferveo/src/bindings_python.rs | 10 +++++--
ferveo/src/bindings_wasm.rs | 3 +-
ferveo/src/dkg.rs | 2 +-
ferveo/src/lib.rs | 14 +++++----
ferveo/src/pvss.rs | 30 +++++++++----------
6 files changed, 85 insertions(+), 30 deletions(-)
diff --git a/ferveo/src/api.rs b/ferveo/src/api.rs
index aab043d6..c8a9f6c3 100644
--- a/ferveo/src/api.rs
+++ b/ferveo/src/api.rs
@@ -246,7 +246,7 @@ impl Dkg {
for (validator, transcript) in messages {
self.0.deal(validator, transcript)?;
}
- Ok(AggregatedTranscript(crate::pvss::aggregate(&self.0.vss)))
+ Ok(AggregatedTranscript(crate::pvss::aggregate(&self.0.vss)?))
}
pub fn public_params(&self) -> DkgPublicParameters {
@@ -268,9 +268,9 @@ fn make_pvss_map(messages: &[ValidatorMessage]) -> PVSSMap {
pub struct AggregatedTranscript(PubliclyVerifiableSS);
impl AggregatedTranscript {
- pub fn new(messages: &[ValidatorMessage]) -> Self {
+ pub fn new(messages: &[ValidatorMessage]) -> Result {
let pvss_map = make_pvss_map(messages);
- AggregatedTranscript(crate::pvss::aggregate(&pvss_map))
+ Ok(AggregatedTranscript(crate::pvss::aggregate(&pvss_map)?))
}
pub fn verify(
@@ -625,6 +625,10 @@ mod test_ferveo_api {
assert!(result.is_err());
}
+ // Note that the server and client code are using the same underlying
+ // implementation for aggregation and aggregate verification.
+ // Here, we focus on testing user-facing APIs for server and client users.
+
#[test]
fn server_side_local_verification() {
let rng = &mut StdRng::seed_from_u64(0);
@@ -643,6 +647,41 @@ mod test_ferveo_api {
assert!(local_aggregate
.verify(dkg.0.dkg_params.shares_num(), &messages)
.is_ok());
+
+ // Test negative cases
+
+ // Notice that the dkg instance is mutable, so we need to get a fresh one
+ // for every test case
+
+ // Should fail if no transcripts are provided
+ let mut dkg =
+ Dkg::new(TAU, SHARES_NUM, SECURITY_THRESHOLD, &validators, &me)
+ .unwrap();
+ let result = dkg.aggregate_transcripts(&[]);
+ assert!(result.is_err());
+
+ // Not enough transcripts
+ let mut dkg =
+ Dkg::new(TAU, SHARES_NUM, SECURITY_THRESHOLD, &validators, &me)
+ .unwrap();
+ let not_enough_messages = &messages[..SECURITY_THRESHOLD as usize - 1];
+ assert!(not_enough_messages.len() < SECURITY_THRESHOLD as usize);
+ let insufficient_aggregate =
+ dkg.aggregate_transcripts(not_enough_messages).unwrap();
+ let result = insufficient_aggregate.verify(SHARES_NUM, &messages);
+ assert!(result.is_err());
+
+ // Unexpected transcripts in the aggregate or transcripts from a different ritual
+ // Using same DKG parameters, but different DKG instances and validators
+ let mut dkg =
+ Dkg::new(TAU, SHARES_NUM, SECURITY_THRESHOLD, &validators, &me)
+ .unwrap();
+ let (bad_messages, _, _) =
+ make_test_inputs(rng, TAU, SECURITY_THRESHOLD, SHARES_NUM);
+ let mixed_messages = [&messages[..2], &bad_messages[..1]].concat();
+ let bad_aggregate = dkg.aggregate_transcripts(&mixed_messages).unwrap();
+ let result = bad_aggregate.verify(SHARES_NUM, &messages);
+ assert!(result.is_err());
}
#[test]
@@ -656,7 +695,8 @@ mod test_ferveo_api {
let messages = &messages[..SECURITY_THRESHOLD as usize];
// Create an aggregated transcript on the client side
- let aggregated_transcript = AggregatedTranscript::new(messages);
+ let aggregated_transcript =
+ AggregatedTranscript::new(messages).unwrap();
// We are separating the verification from the aggregation since the client may fetch
// the aggregate from a side-channel or decide to persist it and verify it later
@@ -668,11 +708,15 @@ mod test_ferveo_api {
// Test negative cases
+ // Should fail if no transcripts are provided
+ let result = AggregatedTranscript::new(&[]);
+ assert!(result.is_err());
+
// Not enough transcripts
let not_enough_messages = &messages[..SECURITY_THRESHOLD as usize - 1];
assert!(not_enough_messages.len() < SECURITY_THRESHOLD as usize);
let insufficient_aggregate =
- AggregatedTranscript::new(not_enough_messages);
+ AggregatedTranscript::new(not_enough_messages).unwrap();
let result = insufficient_aggregate.verify(SHARES_NUM, messages);
assert!(result.is_err());
@@ -681,7 +725,7 @@ mod test_ferveo_api {
let (bad_messages, _, _) =
make_test_inputs(rng, TAU, SECURITY_THRESHOLD, SHARES_NUM);
let mixed_messages = [&messages[..2], &bad_messages[..1]].concat();
- let bad_aggregate = AggregatedTranscript::new(&mixed_messages);
+ let bad_aggregate = AggregatedTranscript::new(&mixed_messages).unwrap();
let result = bad_aggregate.verify(SHARES_NUM, messages);
assert!(result.is_err());
}
diff --git a/ferveo/src/bindings_python.rs b/ferveo/src/bindings_python.rs
index c35dc291..fab67561 100644
--- a/ferveo/src/bindings_python.rs
+++ b/ferveo/src/bindings_python.rs
@@ -113,6 +113,9 @@ impl From for PyErr {
"{index}"
))
},
+ Error::NoTranscriptsToAggregate => {
+ NoTranscriptsToAggregate::new_err("")
+ },
},
_ => default(),
}
@@ -149,6 +152,7 @@ create_exception!(exceptions, InvalidByteLength, PyValueError);
create_exception!(exceptions, InvalidVariant, PyValueError);
create_exception!(exceptions, InvalidDkgParameters, PyValueError);
create_exception!(exceptions, InvalidShareIndex, PyValueError);
+create_exception!(exceptions, NoTranscriptsToAggregate, PyValueError);
fn from_py_bytes(bytes: &[u8]) -> PyResult {
T::from_bytes(bytes)
@@ -580,10 +584,12 @@ generate_bytes_serialization!(AggregatedTranscript);
#[pymethods]
impl AggregatedTranscript {
#[new]
- pub fn new(messages: Vec) -> Self {
+ pub fn new(messages: Vec) -> PyResult {
let messages: Vec<_> =
messages.into_iter().map(|vm| vm.to_inner()).collect();
- Self(api::AggregatedTranscript::new(&messages))
+ let inner = api::AggregatedTranscript::new(&messages)
+ .map_err(FerveoPythonError::FerveoError)?;
+ Ok(Self(inner))
}
pub fn verify(
diff --git a/ferveo/src/bindings_wasm.rs b/ferveo/src/bindings_wasm.rs
index 5a23909a..1396de13 100644
--- a/ferveo/src/bindings_wasm.rs
+++ b/ferveo/src/bindings_wasm.rs
@@ -507,7 +507,8 @@ impl AggregatedTranscript {
) -> JsResult {
set_panic_hook();
let messages = unwrap_messages_js(messages)?;
- let aggregated_transcript = api::AggregatedTranscript::new(&messages);
+ let aggregated_transcript =
+ api::AggregatedTranscript::new(&messages).unwrap();
Ok(AggregatedTranscript(aggregated_transcript))
}
diff --git a/ferveo/src/dkg.rs b/ferveo/src/dkg.rs
index 10f73d7e..8259e5c3 100644
--- a/ferveo/src/dkg.rs
+++ b/ferveo/src/dkg.rs
@@ -179,7 +179,7 @@ impl PubliclyVerifiableDkg {
DkgState::Dealt => {
let public_key = self.public_key();
Ok(Message::Aggregate(Aggregation {
- vss: aggregate(&self.vss),
+ vss: aggregate(&self.vss)?,
public_key,
}))
}
diff --git a/ferveo/src/lib.rs b/ferveo/src/lib.rs
index 0fca990c..f2eb357a 100644
--- a/ferveo/src/lib.rs
+++ b/ferveo/src/lib.rs
@@ -113,6 +113,10 @@ pub enum Error {
/// DKG may not contain duplicated share indices
#[error("Duplicated share index: {0}")]
DuplicatedShareIndex(u32),
+
+ /// Creating a transcript aggregate requires at least one transcript
+ #[error("No transcripts to aggregate")]
+ NoTranscriptsToAggregate,
}
pub type Result = std::result::Result;
@@ -148,7 +152,7 @@ mod test_dkg_full {
Vec>,
SharedSecret,
) {
- let pvss_aggregated = aggregate(&dkg.vss);
+ let pvss_aggregated = aggregate(&dkg.vss).unwrap();
assert!(pvss_aggregated.verify_aggregation(dkg).is_ok());
let decryption_shares: Vec> =
@@ -243,7 +247,7 @@ mod test_dkg_full {
)
.unwrap();
- let pvss_aggregated = aggregate(&dkg.vss);
+ let pvss_aggregated = aggregate(&dkg.vss).unwrap();
pvss_aggregated.verify_aggregation(&dkg).unwrap();
let domain_points = dkg
.domain
@@ -430,7 +434,7 @@ mod test_dkg_full {
// Creates updated private key shares
// TODO: Why not using dkg.aggregate()?
- let pvss_aggregated = aggregate(&dkg.vss);
+ let pvss_aggregated = aggregate(&dkg.vss).unwrap();
pvss_aggregated
.update_private_key_share_for_recovery(
&decryption_key,
@@ -461,7 +465,7 @@ mod test_dkg_full {
.enumerate()
.map(|(share_index, validator_keypair)| {
// TODO: Why not using dkg.aggregate()?
- let pvss_aggregated = aggregate(&dkg.vss);
+ let pvss_aggregated = aggregate(&dkg.vss).unwrap();
pvss_aggregated
.make_decryption_share_simple(
&ciphertext.header().unwrap(),
@@ -573,7 +577,7 @@ mod test_dkg_full {
// Creates updated private key shares
// TODO: Why not using dkg.aggregate()?
- let pvss_aggregated = aggregate(&dkg.vss);
+ let pvss_aggregated = aggregate(&dkg.vss).unwrap();
pvss_aggregated
.update_private_key_share_for_recovery(
&decryption_key,
diff --git a/ferveo/src/pvss.rs b/ferveo/src/pvss.rs
index a5e458d7..60910158 100644
--- a/ferveo/src/pvss.rs
+++ b/ferveo/src/pvss.rs
@@ -381,13 +381,13 @@ impl PubliclyVerifiableSS {
/// Aggregate the PVSS instances in `pvss` from DKG session `dkg`
/// into a new PVSS instance
/// See: https://nikkolasg.github.io/ferveo/pvss.html?highlight=aggregate#aggregation
-pub fn aggregate(
+pub(crate) fn aggregate(
pvss_map: &PVSSMap,
-) -> PubliclyVerifiableSS {
- let mut pvss_iter = pvss_map.iter();
- let (_, first_pvss) = pvss_iter
+) -> Result> {
+ let mut pvss_iter = pvss_map.values();
+ let first_pvss = pvss_iter
.next()
- .expect("May not aggregate empty PVSS instances");
+ .ok_or_else(|| Error::NoTranscriptsToAggregate)?;
let mut coeffs = batch_to_projective_g1::(&first_pvss.coeffs);
let mut sigma = first_pvss.sigma;
@@ -396,25 +396,25 @@ pub fn aggregate(
// So now we're iterating over the PVSS instances, and adding their coefficients and shares, and their sigma
// sigma is the sum of all the sigma_i, which is the proof of knowledge of the secret polynomial
// Aggregating is just adding the corresponding values in pvss instances, so pvss = pvss + pvss_j
- for (_, next) in pvss_iter {
- sigma = (sigma + next.sigma).into();
+ for next_pvss in pvss_iter {
+ sigma = (sigma + next_pvss.sigma).into();
coeffs
.iter_mut()
- .zip_eq(next.coeffs.iter())
+ .zip_eq(next_pvss.coeffs.iter())
.for_each(|(a, b)| *a += b);
shares
.iter_mut()
- .zip_eq(next.shares.iter())
+ .zip_eq(next_pvss.shares.iter())
.for_each(|(a, b)| *a += b);
}
let shares = E::G2::normalize_batch(&shares);
- PubliclyVerifiableSS {
+ Ok(PubliclyVerifiableSS {
coeffs: E::G1::normalize_batch(&coeffs),
shares,
sigma,
phantom: Default::default(),
- }
+ })
}
#[cfg(test)]
@@ -526,7 +526,7 @@ mod test_pvss {
#[test]
fn test_aggregate_pvss() {
let (dkg, _) = setup_dealt_dkg();
- let aggregate = aggregate(&dkg.vss);
+ let aggregate = aggregate(&dkg.vss).unwrap();
// Check that a polynomial of the correct degree was created
assert_eq!(
aggregate.coeffs.len(),
@@ -542,15 +542,15 @@ mod test_pvss {
assert!(aggregate.verify_aggregation(&dkg).expect("Test failed"),);
}
- /// Check that if the aggregated pvss transcript has an
+ /// Check that if the aggregated PVSS transcript has an
/// incorrect constant term, the verification fails
#[test]
fn test_verify_aggregation_fails_if_constant_term_wrong() {
let (dkg, _) = setup_dealt_dkg();
- let mut aggregated = aggregate(&dkg.vss);
+ let mut aggregated = aggregate(&dkg.vss).unwrap();
while aggregated.coeffs[0] == G1::zero() {
let (dkg, _) = setup_dkg(0);
- aggregated = aggregate(&dkg.vss);
+ aggregated = aggregate(&dkg.vss).unwrap();
}
aggregated.coeffs[0] = G1::zero();
assert_eq!(