Skip to content

Commit

Permalink
feat(dkg): prevent panics during transcript aggregation
Browse files Browse the repository at this point in the history
  • Loading branch information
piotr-roslaniec committed Jan 29, 2024
1 parent d547a81 commit d402183
Show file tree
Hide file tree
Showing 6 changed files with 85 additions and 30 deletions.
56 changes: 50 additions & 6 deletions ferveo/src/api.rs
Original file line number Diff line number Diff line change
Expand Up @@ -246,7 +246,7 @@ impl Dkg {
for (validator, transcript) in messages {
self.0.deal(validator, transcript)?;
}
Ok(AggregatedTranscript(crate::pvss::aggregate(&self.0.vss)))
Ok(AggregatedTranscript(crate::pvss::aggregate(&self.0.vss)?))
}

pub fn public_params(&self) -> DkgPublicParameters {
Expand All @@ -268,9 +268,9 @@ fn make_pvss_map(messages: &[ValidatorMessage]) -> PVSSMap<E> {
pub struct AggregatedTranscript(PubliclyVerifiableSS<E, crate::Aggregated>);

impl AggregatedTranscript {
pub fn new(messages: &[ValidatorMessage]) -> Self {
pub fn new(messages: &[ValidatorMessage]) -> Result<Self> {
let pvss_map = make_pvss_map(messages);
AggregatedTranscript(crate::pvss::aggregate(&pvss_map))
Ok(AggregatedTranscript(crate::pvss::aggregate(&pvss_map)?))
}

pub fn verify(
Expand Down Expand Up @@ -625,6 +625,10 @@ mod test_ferveo_api {
assert!(result.is_err());
}

// Note that the server and client code are using the same underlying
// implementation for aggregation and aggregate verification.
// Here, we focus on testing user-facing APIs for server and client users.

#[test]
fn server_side_local_verification() {
let rng = &mut StdRng::seed_from_u64(0);
Expand All @@ -643,6 +647,41 @@ mod test_ferveo_api {
assert!(local_aggregate
.verify(dkg.0.dkg_params.shares_num(), &messages)
.is_ok());

// Test negative cases

// Notice that the dkg instance is mutable, so we need to get a fresh one
// for every test case

// Should fail if no transcripts are provided
let mut dkg =
Dkg::new(TAU, SHARES_NUM, SECURITY_THRESHOLD, &validators, &me)
.unwrap();
let result = dkg.aggregate_transcripts(&[]);
assert!(result.is_err());

// Not enough transcripts
let mut dkg =
Dkg::new(TAU, SHARES_NUM, SECURITY_THRESHOLD, &validators, &me)
.unwrap();
let not_enough_messages = &messages[..SECURITY_THRESHOLD as usize - 1];
assert!(not_enough_messages.len() < SECURITY_THRESHOLD as usize);
let insufficient_aggregate =
dkg.aggregate_transcripts(not_enough_messages).unwrap();
let result = insufficient_aggregate.verify(SHARES_NUM, &messages);
assert!(result.is_err());

// Unexpected transcripts in the aggregate or transcripts from a different ritual
// Using same DKG parameters, but different DKG instances and validators
let mut dkg =
Dkg::new(TAU, SHARES_NUM, SECURITY_THRESHOLD, &validators, &me)
.unwrap();
let (bad_messages, _, _) =
make_test_inputs(rng, TAU, SECURITY_THRESHOLD, SHARES_NUM);
let mixed_messages = [&messages[..2], &bad_messages[..1]].concat();
let bad_aggregate = dkg.aggregate_transcripts(&mixed_messages).unwrap();
let result = bad_aggregate.verify(SHARES_NUM, &messages);
assert!(result.is_err());
}

#[test]
Expand All @@ -656,7 +695,8 @@ mod test_ferveo_api {
let messages = &messages[..SECURITY_THRESHOLD as usize];

// Create an aggregated transcript on the client side
let aggregated_transcript = AggregatedTranscript::new(messages);
let aggregated_transcript =
AggregatedTranscript::new(messages).unwrap();

// We are separating the verification from the aggregation since the client may fetch
// the aggregate from a side-channel or decide to persist it and verify it later
Expand All @@ -668,11 +708,15 @@ mod test_ferveo_api {

// Test negative cases

// Should fail if no transcripts are provided
let result = AggregatedTranscript::new(&[]);
assert!(result.is_err());

// Not enough transcripts
let not_enough_messages = &messages[..SECURITY_THRESHOLD as usize - 1];
assert!(not_enough_messages.len() < SECURITY_THRESHOLD as usize);
let insufficient_aggregate =
AggregatedTranscript::new(not_enough_messages);
AggregatedTranscript::new(not_enough_messages).unwrap();
let result = insufficient_aggregate.verify(SHARES_NUM, messages);
assert!(result.is_err());

Expand All @@ -681,7 +725,7 @@ mod test_ferveo_api {
let (bad_messages, _, _) =
make_test_inputs(rng, TAU, SECURITY_THRESHOLD, SHARES_NUM);
let mixed_messages = [&messages[..2], &bad_messages[..1]].concat();
let bad_aggregate = AggregatedTranscript::new(&mixed_messages);
let bad_aggregate = AggregatedTranscript::new(&mixed_messages).unwrap();
let result = bad_aggregate.verify(SHARES_NUM, messages);
assert!(result.is_err());
}
Expand Down
10 changes: 8 additions & 2 deletions ferveo/src/bindings_python.rs
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,9 @@ impl From<FerveoPythonError> for PyErr {
"{index}"
))
},
Error::NoTranscriptsToAggregate => {
NoTranscriptsToAggregate::new_err("")
},
},
_ => default(),
}
Expand Down Expand Up @@ -149,6 +152,7 @@ create_exception!(exceptions, InvalidByteLength, PyValueError);
create_exception!(exceptions, InvalidVariant, PyValueError);
create_exception!(exceptions, InvalidDkgParameters, PyValueError);
create_exception!(exceptions, InvalidShareIndex, PyValueError);
create_exception!(exceptions, NoTranscriptsToAggregate, PyValueError);

fn from_py_bytes<T: FromBytes>(bytes: &[u8]) -> PyResult<T> {
T::from_bytes(bytes)
Expand Down Expand Up @@ -580,10 +584,12 @@ generate_bytes_serialization!(AggregatedTranscript);
#[pymethods]
impl AggregatedTranscript {
#[new]
pub fn new(messages: Vec<ValidatorMessage>) -> Self {
pub fn new(messages: Vec<ValidatorMessage>) -> PyResult<Self> {
let messages: Vec<_> =
messages.into_iter().map(|vm| vm.to_inner()).collect();
Self(api::AggregatedTranscript::new(&messages))
let inner = api::AggregatedTranscript::new(&messages)
.map_err(FerveoPythonError::FerveoError)?;
Ok(Self(inner))
}

pub fn verify(
Expand Down
3 changes: 2 additions & 1 deletion ferveo/src/bindings_wasm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -507,7 +507,8 @@ impl AggregatedTranscript {
) -> JsResult<AggregatedTranscript> {
set_panic_hook();
let messages = unwrap_messages_js(messages)?;
let aggregated_transcript = api::AggregatedTranscript::new(&messages);
let aggregated_transcript =
api::AggregatedTranscript::new(&messages).unwrap();
Ok(AggregatedTranscript(aggregated_transcript))
}

Expand Down
2 changes: 1 addition & 1 deletion ferveo/src/dkg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,7 @@ impl<E: Pairing> PubliclyVerifiableDkg<E> {
DkgState::Dealt => {
let public_key = self.public_key();
Ok(Message::Aggregate(Aggregation {
vss: aggregate(&self.vss),
vss: aggregate(&self.vss)?,
public_key,
}))
}
Expand Down
14 changes: 9 additions & 5 deletions ferveo/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,10 @@ pub enum Error {
/// DKG may not contain duplicated share indices
#[error("Duplicated share index: {0}")]
DuplicatedShareIndex(u32),

/// DKG instance contains no transcripts
#[error("No transcripts to aggregate")]
NoTranscriptsToAggregate,
}

pub type Result<T> = std::result::Result<T, Error>;
Expand Down Expand Up @@ -148,7 +152,7 @@ mod test_dkg_full {
Vec<DecryptionShareSimple<E>>,
SharedSecret<E>,
) {
let pvss_aggregated = aggregate(&dkg.vss);
let pvss_aggregated = aggregate(&dkg.vss).unwrap();
assert!(pvss_aggregated.verify_aggregation(dkg).is_ok());

let decryption_shares: Vec<DecryptionShareSimple<E>> =
Expand Down Expand Up @@ -243,7 +247,7 @@ mod test_dkg_full {
)
.unwrap();

let pvss_aggregated = aggregate(&dkg.vss);
let pvss_aggregated = aggregate(&dkg.vss).unwrap();
pvss_aggregated.verify_aggregation(&dkg).unwrap();
let domain_points = dkg
.domain
Expand Down Expand Up @@ -430,7 +434,7 @@ mod test_dkg_full {

// Creates updated private key shares
// TODO: Why not using dkg.aggregate()?
let pvss_aggregated = aggregate(&dkg.vss);
let pvss_aggregated = aggregate(&dkg.vss).unwrap();
pvss_aggregated
.update_private_key_share_for_recovery(
&decryption_key,
Expand Down Expand Up @@ -461,7 +465,7 @@ mod test_dkg_full {
.enumerate()
.map(|(share_index, validator_keypair)| {
// TODO: Why not using dkg.aggregate()?
let pvss_aggregated = aggregate(&dkg.vss);
let pvss_aggregated = aggregate(&dkg.vss).unwrap();
pvss_aggregated
.make_decryption_share_simple(
&ciphertext.header().unwrap(),
Expand Down Expand Up @@ -573,7 +577,7 @@ mod test_dkg_full {

// Creates updated private key shares
// TODO: Why not using dkg.aggregate()?
let pvss_aggregated = aggregate(&dkg.vss);
let pvss_aggregated = aggregate(&dkg.vss).unwrap();
pvss_aggregated
.update_private_key_share_for_recovery(
&decryption_key,
Expand Down
30 changes: 15 additions & 15 deletions ferveo/src/pvss.rs
Original file line number Diff line number Diff line change
Expand Up @@ -381,13 +381,13 @@ impl<E: Pairing, T: Aggregate> PubliclyVerifiableSS<E, T> {
/// Aggregate the PVSS instances in `pvss` from DKG session `dkg`
/// into a new PVSS instance
/// See: https://nikkolasg.github.io/ferveo/pvss.html?highlight=aggregate#aggregation
pub fn aggregate<E: Pairing>(
pub(crate) fn aggregate<E: Pairing>(
pvss_map: &PVSSMap<E>,
) -> PubliclyVerifiableSS<E, Aggregated> {
let mut pvss_iter = pvss_map.iter();
let (_, first_pvss) = pvss_iter
) -> Result<PubliclyVerifiableSS<E, Aggregated>> {
let mut pvss_iter = pvss_map.values();
let first_pvss = pvss_iter
.next()
.expect("May not aggregate empty PVSS instances");
.ok_or_else(|| Error::NoTranscriptsToAggregate)?;
let mut coeffs = batch_to_projective_g1::<E>(&first_pvss.coeffs);
let mut sigma = first_pvss.sigma;

Expand All @@ -396,25 +396,25 @@ pub fn aggregate<E: Pairing>(
// So now we're iterating over the PVSS instances, and adding their coefficients and shares, and their sigma
// sigma is the sum of all the sigma_i, which is the proof of knowledge of the secret polynomial
// Aggregating is just adding the corresponding values in pvss instances, so pvss = pvss + pvss_j
for (_, next) in pvss_iter {
sigma = (sigma + next.sigma).into();
for next_pvss in pvss_iter {
sigma = (sigma + next_pvss.sigma).into();
coeffs
.iter_mut()
.zip_eq(next.coeffs.iter())
.zip_eq(next_pvss.coeffs.iter())
.for_each(|(a, b)| *a += b);
shares
.iter_mut()
.zip_eq(next.shares.iter())
.zip_eq(next_pvss.shares.iter())
.for_each(|(a, b)| *a += b);
}
let shares = E::G2::normalize_batch(&shares);

PubliclyVerifiableSS {
Ok(PubliclyVerifiableSS {
coeffs: E::G1::normalize_batch(&coeffs),
shares,
sigma,
phantom: Default::default(),
}
})
}

#[cfg(test)]
Expand Down Expand Up @@ -526,7 +526,7 @@ mod test_pvss {
#[test]
fn test_aggregate_pvss() {
let (dkg, _) = setup_dealt_dkg();
let aggregate = aggregate(&dkg.vss);
let aggregate = aggregate(&dkg.vss).unwrap();
// Check that a polynomial of the correct degree was created
assert_eq!(
aggregate.coeffs.len(),
Expand All @@ -542,15 +542,15 @@ mod test_pvss {
assert!(aggregate.verify_aggregation(&dkg).expect("Test failed"),);
}

/// Check that if the aggregated PVSS transcript has an
/// incorrect constant term, the verification fails
#[test]
fn test_verify_aggregation_fails_if_constant_term_wrong() {
let (dkg, _) = setup_dealt_dkg();
let mut aggregated = aggregate(&dkg.vss);
let mut aggregated = aggregate(&dkg.vss).unwrap();
while aggregated.coeffs[0] == G1::zero() {
let (dkg, _) = setup_dkg(0);
aggregated = aggregate(&dkg.vss);
aggregated = aggregate(&dkg.vss).unwrap();
}
aggregated.coeffs[0] = G1::zero();
assert_eq!(
Expand Down

0 comments on commit d402183

Please sign in to comment.