From 66d25aecb5a3e29784f6d2ef1a7977ce4a2d406a Mon Sep 17 00:00:00 2001 From: Piotr Roslaniec Date: Thu, 18 Jan 2024 19:05:50 +0100 Subject: [PATCH] refactor(test): use test_case crate to deduplicate tests --- Cargo.lock | 34 +++++ ferveo/Cargo.toml | 1 + ferveo/src/api.rs | 346 ++++++++++++++++++++++------------------------ ferveo/src/dkg.rs | 10 +- ferveo/src/lib.rs | 177 ++++++++++++------------ 5 files changed, 293 insertions(+), 275 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 7aa7e272..685f931f 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -786,6 +786,7 @@ dependencies = [ "serde", "serde_with", "subproductdomain-pre-release", + "test-case", "thiserror", "wasm-bindgen", "wasm-bindgen-derive", @@ -1881,6 +1882,39 @@ dependencies = [ "winapi-util", ] +[[package]] +name = "test-case" +version = "3.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "eb2550dd13afcd286853192af8601920d959b14c401fcece38071d53bf0768a8" +dependencies = [ + "test-case-macros", +] + +[[package]] +name = "test-case-core" +version = "3.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "adcb7fd841cd518e279be3d5a3eb0636409487998a4aff22f3de87b81e88384f" +dependencies = [ + "cfg-if", + "proc-macro2", + "quote", + "syn 2.0.15", +] + +[[package]] +name = "test-case-macros" +version = "3.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5c89e72a01ed4c579669add59014b9a524d609c0c88c6a585ce37485879f6ffb" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.15", + "test-case-core", +] + [[package]] name = "textwrap" version = "0.11.0" diff --git a/ferveo/Cargo.toml b/ferveo/Cargo.toml index c063c502..d2a0d0eb 100644 --- a/ferveo/Cargo.toml +++ b/ferveo/Cargo.toml @@ -51,6 +51,7 @@ wasm-bindgen-derive = { version = "0.2.1", optional = true } criterion = "0.3" # supports pprof, # TODO: Figure out if/how we can update to 0.4 digest = { version = "0.10.0", features = ["alloc"] } pprof = { version = "0.6", features = ["flamegraph", "criterion"] } +test-case = "3.3.1" # WASM bindings console_error_panic_hook = "0.1.7" diff --git a/ferveo/src/api.rs b/ferveo/src/api.rs index be8f5b1f..0e610658 100644 --- a/ferveo/src/api.rs +++ b/ferveo/src/api.rs @@ -397,11 +397,14 @@ mod test_ferveo_api { use ferveo_tdec::SecretBox; use itertools::izip; use rand::{prelude::StdRng, SeedableRng}; + use test_case::test_case; use crate::{api::*, dkg::test_common::*}; type TestInputs = (Vec, Vec, Vec); + const TAU: u32 = 1; + fn make_test_inputs( rng: &mut StdRng, tau: u32, @@ -446,204 +449,188 @@ mod test_ferveo_api { assert_eq!(dkg_pk, deserialized); } - #[test] - fn test_server_api_tdec_precomputed() { + #[test_case(4; "number of shares (validators) is a power of 2")] + #[test_case(7; "number of shares (validators) is not a power of 2")] + fn test_server_api_tdec_precomputed(shares_num: u32) { let rng = &mut StdRng::seed_from_u64(0); - // Works for both power of 2 and non-power of 2 - for shares_num in [4, 7] { - let tau = 1; - // In precomputed variant, the security threshold is equal to the number of shares - // TODO: Refactor DKG constructor to not require security threshold or this case. - // Or figure out a different way to simplify the precomputed variant API. - let security_threshold = shares_num; - - let (messages, validators, validator_keypairs) = - make_test_inputs(rng, tau, security_threshold, shares_num); - - // Now that every validator holds a dkg instance and a transcript for every other validator, - // every validator can aggregate the transcripts - let me = validators[0].clone(); - let mut dkg = - Dkg::new(tau, shares_num, security_threshold, &validators, &me) - .unwrap(); - - let pvss_aggregated = dkg.aggregate_transcripts(&messages).unwrap(); - assert!(pvss_aggregated.verify(shares_num, &messages).unwrap()); - - // At this point, any given validator should be able to provide a DKG public key - let dkg_public_key = dkg.public_key(); - - // In the meantime, the client creates a ciphertext and decryption request - let msg = "my-msg".as_bytes().to_vec(); - let aad: &[u8] = "my-aad".as_bytes(); - let ciphertext = - encrypt(SecretBox::new(msg.clone()), aad, &dkg_public_key) - .unwrap(); - - // Having aggregated the transcripts, the validators can now create decryption shares - let decryption_shares: Vec<_> = - izip!(&validators, &validator_keypairs) - .map(|(validator, validator_keypair)| { - // Each validator holds their own instance of DKG and creates their own aggregate - let mut dkg = Dkg::new( - tau, - shares_num, - security_threshold, - &validators, - validator, - ) - .unwrap(); - let aggregate = - dkg.aggregate_transcripts(&messages).unwrap(); - assert!(pvss_aggregated - .verify(shares_num, &messages) - .unwrap()); - - // And then each validator creates their own decryption share - aggregate - .create_decryption_share_precomputed( - &dkg, - &ciphertext.header().unwrap(), - aad, - validator_keypair, - ) - .unwrap() - }) - .collect(); - - // Now, the decryption share can be used to decrypt the ciphertext - // This part is part of the client API - - let shared_secret = share_combine_precomputed(&decryption_shares); - let plaintext = decrypt_with_shared_secret( - &ciphertext, - aad, - &SharedSecret(shared_secret), - ) - .unwrap(); - assert_eq!(plaintext, msg); - - // Since we're using a precomputed variant, we need all the shares to be able to decrypt - // So if we remove one share, we should not be able to decrypt - let decryption_shares = - decryption_shares[..shares_num as usize - 1].to_vec(); - - let shared_secret = share_combine_precomputed(&decryption_shares); - let result = decrypt_with_shared_secret( - &ciphertext, - aad, - &SharedSecret(shared_secret), - ); - assert!(result.is_err()); - } + // In precomputed variant, the security threshold is equal to the number of shares + // TODO: Refactor DKG constructor to not require security threshold or this case. + // Or figure out a different way to simplify the precomputed variant API. + let security_threshold = shares_num; + + let (messages, validators, validator_keypairs) = + make_test_inputs(rng, TAU, security_threshold, shares_num); + + // Now that every validator holds a dkg instance and a transcript for every other validator, + // every validator can aggregate the transcripts + let me = validators[0].clone(); + let mut dkg = + Dkg::new(TAU, shares_num, security_threshold, &validators, &me) + .unwrap(); + + let pvss_aggregated = dkg.aggregate_transcripts(&messages).unwrap(); + assert!(pvss_aggregated.verify(shares_num, &messages).unwrap()); + + // At this point, any given validator should be able to provide a DKG public key + let dkg_public_key = dkg.public_key(); + + // In the meantime, the client creates a ciphertext and decryption request + let msg = "my-msg".as_bytes().to_vec(); + let aad: &[u8] = "my-aad".as_bytes(); + let ciphertext = + encrypt(SecretBox::new(msg.clone()), aad, &dkg_public_key).unwrap(); + + // Having aggregated the transcripts, the validators can now create decryption shares + let decryption_shares: Vec<_> = izip!(&validators, &validator_keypairs) + .map(|(validator, validator_keypair)| { + // Each validator holds their own instance of DKG and creates their own aggregate + let mut dkg = Dkg::new( + TAU, + shares_num, + security_threshold, + &validators, + validator, + ) + .unwrap(); + let aggregate = dkg.aggregate_transcripts(&messages).unwrap(); + assert!(pvss_aggregated.verify(shares_num, &messages).unwrap()); + + // And then each validator creates their own decryption share + aggregate + .create_decryption_share_precomputed( + &dkg, + &ciphertext.header().unwrap(), + aad, + validator_keypair, + ) + .unwrap() + }) + .collect(); + + // Now, the decryption share can be used to decrypt the ciphertext + // This part is part of the client API + + let shared_secret = share_combine_precomputed(&decryption_shares); + let plaintext = decrypt_with_shared_secret( + &ciphertext, + aad, + &SharedSecret(shared_secret), + ) + .unwrap(); + assert_eq!(plaintext, msg); + + // Since we're using a precomputed variant, we need all the shares to be able to decrypt + // So if we remove one share, we should not be able to decrypt + let decryption_shares = + decryption_shares[..shares_num as usize - 1].to_vec(); + + let shared_secret = share_combine_precomputed(&decryption_shares); + let result = decrypt_with_shared_secret( + &ciphertext, + aad, + &SharedSecret(shared_secret), + ); + assert!(result.is_err()); } - #[test] - fn test_server_api_tdec_simple() { + #[test_case(4; "number of shares (validators) is a power of 2")] + #[test_case(7; "number of shares (validators) is not a power of 2")] + fn test_server_api_tdec_simple(shares_num: u32) { let rng = &mut StdRng::seed_from_u64(0); - // Works for both power of 2 and non-power of 2 - for shares_num in [4, 7] { - let tau = 1; - let security_threshold = shares_num / 2 + 1; - - let (messages, validators, validator_keypairs) = - make_test_inputs(rng, tau, security_threshold, shares_num); - - // Now that every validator holds a dkg instance and a transcript for every other validator, - // every validator can aggregate the transcripts - let mut dkg = Dkg::new( - tau, - shares_num, - security_threshold, - &validators, - &validators[0], - ) - .unwrap(); - - let pvss_aggregated = dkg.aggregate_transcripts(&messages).unwrap(); - assert!(pvss_aggregated.verify(shares_num, &messages).unwrap()); - - // At this point, any given validator should be able to provide a DKG public key - let public_key = dkg.public_key(); - - // In the meantime, the client creates a ciphertext and decryption request - let msg = "my-msg".as_bytes().to_vec(); - let aad: &[u8] = "my-aad".as_bytes(); - let ciphertext = - encrypt(SecretBox::new(msg.clone()), aad, &public_key).unwrap(); - - // Having aggregated the transcripts, the validators can now create decryption shares - let decryption_shares: Vec<_> = - izip!(&validators, &validator_keypairs) - .map(|(validator, validator_keypair)| { - // Each validator holds their own instance of DKG and creates their own aggregate - let mut dkg = Dkg::new( - tau, - shares_num, - security_threshold, - &validators, - validator, - ) - .unwrap(); - let aggregate = - dkg.aggregate_transcripts(&messages).unwrap(); - assert!(aggregate - .verify(shares_num, &messages) - .unwrap()); - aggregate - .create_decryption_share_simple( - &dkg, - &ciphertext.header().unwrap(), - aad, - validator_keypair, - ) - .unwrap() - }) - .collect(); - - // Now, the decryption share can be used to decrypt the ciphertext - // This part is part of the client API - - // In simple variant, we only need `security_threshold` shares to be able to decrypt - let decryption_shares = - decryption_shares[..security_threshold as usize].to_vec(); - - let shared_secret = combine_shares_simple(&decryption_shares); - let plaintext = - decrypt_with_shared_secret(&ciphertext, aad, &shared_secret) - .unwrap(); - assert_eq!(plaintext, msg); - - // Let's say that we've only received `security_threshold - 1` shares - // In this case, we should not be able to decrypt - let decryption_shares = - decryption_shares[..security_threshold as usize - 1].to_vec(); - - let shared_secret = combine_shares_simple(&decryption_shares); - let result = - decrypt_with_shared_secret(&ciphertext, aad, &shared_secret); - assert!(result.is_err()); - } + let security_threshold = shares_num / 2 + 1; + + let (messages, validators, validator_keypairs) = + make_test_inputs(rng, TAU, security_threshold, shares_num); + + // Now that every validator holds a dkg instance and a transcript for every other validator, + // every validator can aggregate the transcripts + let mut dkg = Dkg::new( + TAU, + shares_num, + security_threshold, + &validators, + &validators[0], + ) + .unwrap(); + + let pvss_aggregated = dkg.aggregate_transcripts(&messages).unwrap(); + assert!(pvss_aggregated.verify(shares_num, &messages).unwrap()); + + // At this point, any given validator should be able to provide a DKG public key + let public_key = dkg.public_key(); + + // In the meantime, the client creates a ciphertext and decryption request + let msg = "my-msg".as_bytes().to_vec(); + let aad: &[u8] = "my-aad".as_bytes(); + let ciphertext = + encrypt(SecretBox::new(msg.clone()), aad, &public_key).unwrap(); + + // Having aggregated the transcripts, the validators can now create decryption shares + let decryption_shares: Vec<_> = izip!(&validators, &validator_keypairs) + .map(|(validator, validator_keypair)| { + // Each validator holds their own instance of DKG and creates their own aggregate + let mut dkg = Dkg::new( + TAU, + shares_num, + security_threshold, + &validators, + validator, + ) + .unwrap(); + let aggregate = dkg.aggregate_transcripts(&messages).unwrap(); + assert!(aggregate.verify(shares_num, &messages).unwrap()); + aggregate + .create_decryption_share_simple( + &dkg, + &ciphertext.header().unwrap(), + aad, + validator_keypair, + ) + .unwrap() + }) + .collect(); + + // Now, the decryption share can be used to decrypt the ciphertext + // This part is part of the client API + + // In simple variant, we only need `security_threshold` shares to be able to decrypt + let decryption_shares = + decryption_shares[..security_threshold as usize].to_vec(); + + let shared_secret = combine_shares_simple(&decryption_shares); + let plaintext = + decrypt_with_shared_secret(&ciphertext, aad, &shared_secret) + .unwrap(); + assert_eq!(plaintext, msg); + + // Let's say that we've only received `security_threshold - 1` shares + // In this case, we should not be able to decrypt + let decryption_shares = + decryption_shares[..security_threshold as usize - 1].to_vec(); + + let shared_secret = combine_shares_simple(&decryption_shares); + let result = + decrypt_with_shared_secret(&ciphertext, aad, &shared_secret); + assert!(result.is_err()); } #[test] fn server_side_local_verification() { let rng = &mut StdRng::seed_from_u64(0); - let tau = 1; let security_threshold = 3; let shares_num = 4; let (messages, validators, _) = - make_test_inputs(rng, tau, security_threshold, shares_num); + make_test_inputs(rng, TAU, security_threshold, shares_num); // Now that every validator holds a dkg instance and a transcript for every other validator, // every validator can aggregate the transcripts let me = validators[0].clone(); let mut dkg = - Dkg::new(tau, shares_num, security_threshold, &validators, &me) + Dkg::new(TAU, shares_num, security_threshold, &validators, &me) .unwrap(); let local_aggregate = dkg.aggregate_transcripts(&messages).unwrap(); @@ -656,12 +643,11 @@ mod test_ferveo_api { fn client_side_local_verification() { let rng = &mut StdRng::seed_from_u64(0); - let tau = 1; let security_threshold = 3; let shares_num = 4; let (messages, _, _) = - make_test_inputs(rng, tau, security_threshold, shares_num); + make_test_inputs(rng, TAU, security_threshold, shares_num); // We only need `security_threshold` transcripts to aggregate let messages = &messages[..security_threshold as usize]; @@ -690,7 +676,7 @@ mod test_ferveo_api { // Unexpected transcripts in the aggregate or transcripts from a different ritual // Using same DKG parameters, but different DKG instances and validators let (bad_messages, _, _) = - make_test_inputs(rng, tau, security_threshold, shares_num); + make_test_inputs(rng, TAU, security_threshold, shares_num); let mixed_messages = [&messages[..2], &bad_messages[..1]].concat(); let bad_aggregate = AggregatedTranscript::new(&mixed_messages); let result = bad_aggregate.verify(shares_num, messages); diff --git a/ferveo/src/dkg.rs b/ferveo/src/dkg.rs index 808afd40..dbab8e03 100644 --- a/ferveo/src/dkg.rs +++ b/ferveo/src/dkg.rs @@ -508,18 +508,18 @@ mod test_dealing { fn test_pvss_dealing() { let rng = &mut ark_std::test_rng(); + // Create a test DKG instance + let (mut dkg, _) = setup_dkg(0); + // Gather everyone's transcripts let mut messages = vec![]; - for i in 0..4 { - let (mut dkg, _) = setup_dkg(i); + for i in 0..dkg.dkg_params.shares_num() { + let (mut dkg, _) = setup_dkg(i as usize); let message = dkg.share(rng).unwrap(); let sender = dkg.me.validator.clone(); messages.push((sender, message)); } - // Create a test DKG instance - let (mut dkg, _) = setup_dkg(0); - let mut expected = 0u32; for (sender, pvss) in messages.iter() { // Check the verification passes diff --git a/ferveo/src/lib.rs b/ferveo/src/lib.rs index 4a29f099..605dc0d7 100644 --- a/ferveo/src/lib.rs +++ b/ferveo/src/lib.rs @@ -137,6 +137,7 @@ mod test_dkg_full { SharedSecret, }; use itertools::izip; + use test_case::test_case; use super::*; use crate::dkg::test_common::*; @@ -195,107 +196,103 @@ mod test_dkg_full { (pvss_aggregated, decryption_shares, shared_secret) } - #[test] - fn test_dkg_simple_tdec() { + #[test_case(4; "number of shares (validators) is a power of 2")] + #[test_case(7; "number of shares (validators) is not a power of 2")] + fn test_dkg_simple_tdec(shares_num: u32) { let rng = &mut test_rng(); - // Works for both power of 2 and non-power of 2 - for shares_num in [4, 7] { - let threshold = shares_num / 2 + 1; - let (dkg, validator_keypairs) = - setup_dealt_dkg_with_n_validators(threshold, shares_num); - let msg = "my-msg".as_bytes().to_vec(); - let aad: &[u8] = "my-aad".as_bytes(); - let public_key = dkg.public_key(); - let ciphertext = ferveo_tdec::encrypt::( - SecretBox::new(msg.clone()), - aad, - &public_key, - rng, - ) - .unwrap(); + let threshold = shares_num / 2 + 1; + let (dkg, validator_keypairs) = + setup_dealt_dkg_with_n_validators(threshold, shares_num); + let msg = "my-msg".as_bytes().to_vec(); + let aad: &[u8] = "my-aad".as_bytes(); + let public_key = dkg.public_key(); + let ciphertext = ferveo_tdec::encrypt::( + SecretBox::new(msg.clone()), + aad, + &public_key, + rng, + ) + .unwrap(); - let (_, _, shared_secret) = make_shared_secret_simple_tdec( - &dkg, - aad, - &ciphertext.header().unwrap(), - validator_keypairs.as_slice(), - ); + let (_, _, shared_secret) = make_shared_secret_simple_tdec( + &dkg, + aad, + &ciphertext.header().unwrap(), + validator_keypairs.as_slice(), + ); - let plaintext = ferveo_tdec::decrypt_with_shared_secret( - &ciphertext, - aad, - &shared_secret, - &dkg.pvss_params.g_inv(), - ) - .unwrap(); - assert_eq!(plaintext, msg); - } + let plaintext = ferveo_tdec::decrypt_with_shared_secret( + &ciphertext, + aad, + &shared_secret, + &dkg.pvss_params.g_inv(), + ) + .unwrap(); + assert_eq!(plaintext, msg); } - #[test] - fn test_dkg_simple_tdec_precomputed() { + #[test_case(4; "number of shares (validators) is a power of 2")] + #[test_case(7; "number of shares (validators) is not a power of 2")] + fn test_dkg_simple_tdec_precomputed(shares_num: u32) { let rng = &mut test_rng(); - // Works for both power of 2 and non-power of 2 - for shares_num in [4, 7] { - // In precomputed variant, threshold must be equal to shares_num - let threshold = shares_num; - let (dkg, validator_keypairs) = - setup_dealt_dkg_with_n_validators(threshold, shares_num); - let msg = "my-msg".as_bytes().to_vec(); - let aad: &[u8] = "my-aad".as_bytes(); - let public_key = dkg.public_key(); - let ciphertext = ferveo_tdec::encrypt::( - SecretBox::new(msg.clone()), - aad, - &public_key, - rng, - ) - .unwrap(); + // In precomputed variant, threshold must be equal to shares_num + let threshold = shares_num; + let (dkg, validator_keypairs) = + setup_dealt_dkg_with_n_validators(threshold, shares_num); + let msg = "my-msg".as_bytes().to_vec(); + let aad: &[u8] = "my-aad".as_bytes(); + let public_key = dkg.public_key(); + let ciphertext = ferveo_tdec::encrypt::( + SecretBox::new(msg.clone()), + aad, + &public_key, + rng, + ) + .unwrap(); - let pvss_aggregated = aggregate(&dkg.vss); - pvss_aggregated.verify_aggregation(&dkg).unwrap(); - let domain_points = dkg - .domain - .elements() - .take(validator_keypairs.len()) - .collect::>(); - - let decryption_shares: Vec> = - validator_keypairs - .iter() - .map(|validator_keypair| { - let validator = dkg - .get_validator(&validator_keypair.public_key()) - .unwrap(); - pvss_aggregated - .make_decryption_share_simple_precomputed( - &ciphertext.header().unwrap(), - aad, - &validator_keypair.decryption_key, - validator.share_index, - &domain_points, - &dkg.pvss_params.g_inv(), - ) - .unwrap() - }) - .collect(); - assert_eq!(domain_points.len(), decryption_shares.len()); + let pvss_aggregated = aggregate(&dkg.vss); + pvss_aggregated.verify_aggregation(&dkg).unwrap(); + let domain_points = dkg + .domain + .elements() + .take(validator_keypairs.len()) + .collect::>(); + + let decryption_shares: Vec> = + validator_keypairs + .iter() + .map(|validator_keypair| { + let validator = dkg + .get_validator(&validator_keypair.public_key()) + .unwrap(); + pvss_aggregated + .make_decryption_share_simple_precomputed( + &ciphertext.header().unwrap(), + aad, + &validator_keypair.decryption_key, + validator.share_index, + &domain_points, + &dkg.pvss_params.g_inv(), + ) + .unwrap() + }) + .collect(); + assert_eq!(domain_points.len(), decryption_shares.len()); - let shared_secret = - ferveo_tdec::share_combine_precomputed::(&decryption_shares); + let shared_secret = + ferveo_tdec::share_combine_precomputed::(&decryption_shares); - // Combination works, let's decrypt - let plaintext = ferveo_tdec::decrypt_with_shared_secret( - &ciphertext, - aad, - &shared_secret, - &dkg.pvss_params.g_inv(), - ) - .unwrap(); - assert_eq!(plaintext, msg); - } + // Combination works, let's decrypt + let plaintext = ferveo_tdec::decrypt_with_shared_secret( + &ciphertext, + aad, + &shared_secret, + &dkg.pvss_params.g_inv(), + ) + .unwrap(); + assert_eq!(plaintext, msg); } #[test]