diff --git a/src/inc_encoding/target_sum.rs b/src/inc_encoding/target_sum.rs index 54b0e15..0a735f5 100644 --- a/src/inc_encoding/target_sum.rs +++ b/src/inc_encoding/target_sum.rs @@ -86,3 +86,157 @@ impl IncomparableEncoding MH::internal_consistency_check(); } } + +#[cfg(test)] +mod tests { + use super::*; + use crate::F; + use crate::array::FieldArray; + use crate::symmetric::message_hash::MessageHash; + use crate::symmetric::message_hash::poseidon::PoseidonMessageHash445; + use p3_field::PrimeField32; + use proptest::prelude::*; + + const TEST_TARGET_SUM: usize = 115; + type TestTargetSumEncoding = TargetSumEncoding; + + #[test] + fn test_internal_consistency() { + TestTargetSumEncoding::internal_consistency_check(); + } + + #[test] + fn test_successful_encoding_fixed_message() { + // keep message fixed and only resample randomness + // this mirrors the actual signature scheme behavior + let mut rng = rand::rng(); + let parameter: FieldArray<4> = FieldArray(rng.random()); + let message: [u8; 32] = rng.random(); + let epoch = 0u32; + + // retry with different randomness until encoding succeeds + for _ in 0..1_000 { + let randomness = TestTargetSumEncoding::rand(&mut rng); + + if let Ok(chunks) = + TestTargetSumEncoding::encode(¶meter, &message, &randomness, epoch) + { + // check output has correct dimension + assert_eq!(chunks.len(), TestTargetSumEncoding::DIMENSION); + + // check all chunks are in valid range [0, BASE-1] + for &chunk in &chunks { + assert!((chunk as usize) < TestTargetSumEncoding::BASE); + } + + // check sum equals target + let sum: usize = chunks.iter().map(|&x| x as usize).sum(); + assert_eq!(sum, TEST_TARGET_SUM); + + // check determinism: encoding again with same inputs produces same result + let result2 = + TestTargetSumEncoding::encode(¶meter, &message, &randomness, epoch); + assert_eq!(chunks, result2.unwrap()); + + return; + } + } + + panic!("failed to find successful encoding after 1000 attempts"); + } + + #[test] + fn test_successful_encoding_random_inputs() { + // retry with all random inputs until encoding succeeds + let mut rng = rand::rng(); + let epoch = 0u32; + + for _ in 0..1_000 { + let parameter: FieldArray<4> = FieldArray(rng.random()); + let message: [u8; 32] = rng.random(); + let randomness = TestTargetSumEncoding::rand(&mut rng); + + if let Ok(chunks) = + TestTargetSumEncoding::encode(¶meter, &message, &randomness, epoch) + { + // check output has correct dimension + assert_eq!(chunks.len(), TestTargetSumEncoding::DIMENSION); + + // check all chunks are in valid range [0, BASE-1] + for &chunk in &chunks { + assert!((chunk as usize) < TestTargetSumEncoding::BASE); + } + + // check sum equals target + let sum: usize = chunks.iter().map(|&x| x as usize).sum(); + assert_eq!(sum, TEST_TARGET_SUM); + + // check determinism: encoding again with same inputs produces same result + let result2 = + TestTargetSumEncoding::encode(¶meter, &message, &randomness, epoch); + assert_eq!(chunks, result2.unwrap()); + + return; + } + } + + panic!("failed to find successful encoding after 1000 attempts"); + } + + proptest! { + #[test] + fn proptest_encoding_determinism_and_error_reporting( + message in prop::array::uniform32(any::()), + randomness_values in prop::collection::vec(0u32..F::ORDER_U32, 4), + parameter_values in prop::collection::vec(0u32..F::ORDER_U32, 4), + epoch in any::() + ) { + // build randomness and parameter from proptest values + let randomness_arr: [F; 4] = std::array::from_fn(|i| F::new(randomness_values[i])); + let randomness = FieldArray(randomness_arr); + let parameter_arr: [F; 4] = std::array::from_fn(|i| F::new(parameter_values[i])); + let parameter = FieldArray(parameter_arr); + + // compute expected sum from underlying message hash + let hash_chunks = PoseidonMessageHash445::apply(¶meter, epoch, &randomness, &message); + let hash_sum: usize = hash_chunks.iter().map(|&x| x as usize).sum(); + + // call encode twice to check determinism + let result1 = TestTargetSumEncoding::encode(¶meter, &message, &randomness, epoch); + let result2 = TestTargetSumEncoding::encode(¶meter, &message, &randomness, epoch); + + // check determinism: both calls produce same result + match (&result1, &result2) { + (Ok(c1), Ok(c2)) => prop_assert_eq!(c1, c2), + (Err(TargetSumError::Mismatch { expected: e1, actual: a1 }), + Err(TargetSumError::Mismatch { expected: e2, actual: a2 })) => { + prop_assert_eq!(e1, e2); + prop_assert_eq!(a1, a2); + } + _ => prop_assert!(false, "determinism violated"), + } + + // check properties based on success/failure + match result1 { + Err(TargetSumError::Mismatch { expected, actual }) => { + // check error reports correct values + prop_assert_eq!(expected, TEST_TARGET_SUM); + prop_assert_eq!(actual, hash_sum); + } + Ok(chunks) => { + // check output dimension + prop_assert_eq!(chunks.len(), TestTargetSumEncoding::DIMENSION); + + // check all chunks in valid range + for &chunk in &chunks { + prop_assert!((chunk as usize) < TestTargetSumEncoding::BASE); + } + + // check sum equals target + let sum: usize = chunks.iter().map(|&x| x as usize).sum(); + prop_assert_eq!(sum, TEST_TARGET_SUM); + } + } + } + } +} diff --git a/src/signature/generalized_xmss.rs b/src/signature/generalized_xmss.rs index 1ef1f39..983e57f 100644 --- a/src/signature/generalized_xmss.rs +++ b/src/signature/generalized_xmss.rs @@ -994,6 +994,10 @@ mod tests { use super::*; + use crate::array::FieldArray; + use p3_field::PrimeField32; + use proptest::prelude::*; + use crate::{F, symmetric::tweak_hash::poseidon::PoseidonTweakHash}; use p3_field::RawDataSerializable; use rand::rng; @@ -1535,4 +1539,71 @@ mod tests { // Verify signature from decoded key validates assert!(Sig::verify(&pk, epoch + 1, &message, &sig2)); } + + proptest! { + #[test] + fn proptest_expand_activation_time_invariants( + desired_start in 0usize..256, + desired_duration in 1usize..256 + ) { + const LOG_LIFETIME: usize = 8; + const C: usize = 1 << (LOG_LIFETIME / 2); + const LIFETIME: usize = 1 << LOG_LIFETIME; + + let desired_end = (desired_start + desired_duration).min(LIFETIME); + + let (start, end) = expand_activation_time::(desired_start, desired_duration); + + let actual_start = start * C; + let actual_end = end * C; + + // check minimum duration of 2 bottom trees (each tree has C leaves) + prop_assert!(actual_end - actual_start >= 2 * C); + + // check result fits within lifetime + prop_assert!(actual_end <= LIFETIME); + + // check result contains the desired interval + prop_assert!(actual_start <= desired_start); + prop_assert!(actual_end >= desired_end); + + // check determinism by calling twice + let (start2, end2) = expand_activation_time::(desired_start, desired_duration); + prop_assert_eq!((start, end), (start2, end2)); + } + + #[test] + fn proptest_ssz_public_key_roundtrip_and_determinism( + root_values in prop::collection::vec(0u32..F::ORDER_U32, 7), + param_values in prop::collection::vec(0u32..F::ORDER_U32, 5) + ) { + // build public key from random field element values + let root_arr: [F; 7] = std::array::from_fn(|i| F::new(root_values[i])); + let param_arr: [F; 5] = std::array::from_fn(|i| F::new(param_values[i])); + + let original = GeneralizedXMSSPublicKey:: { + root: FieldArray(root_arr), + parameter: FieldArray(param_arr), + }; + + // encode to SSZ bytes + let encoded1 = original.as_ssz_bytes(); + let encoded2 = original.as_ssz_bytes(); + + // check encoding is deterministic + prop_assert_eq!(&encoded1, &encoded2); + + // check size matches expected (7 + 5 field elements * 4 bytes) + let expected_size = 12 * F::NUM_BYTES; + prop_assert_eq!(encoded1.len(), expected_size); + prop_assert_eq!(original.ssz_bytes_len(), expected_size); + + // decode and check roundtrip preserves data + let decoded = GeneralizedXMSSPublicKey::::from_ssz_bytes(&encoded1) + .expect("valid SSZ bytes should decode"); + + prop_assert_eq!(original.root, decoded.root); + prop_assert_eq!(original.parameter, decoded.parameter); + } + } } diff --git a/src/symmetric/message_hash/poseidon.rs b/src/symmetric/message_hash/poseidon.rs index 01e0503..fba524f 100644 --- a/src/symmetric/message_hash/poseidon.rs +++ b/src/symmetric/message_hash/poseidon.rs @@ -232,6 +232,8 @@ pub type PoseidonMessageHashW1 = PoseidonMessageHash<5, 5, 5, 163, 2, 2, 9>; mod tests { use super::*; use num_traits::Zero; + use p3_field::PrimeField32; + use proptest::prelude::*; use rand::Rng; use std::collections::HashMap; @@ -610,4 +612,100 @@ mod tests { "Reconstructed bigint from chunks does not match bigint from field elements" ); } + + proptest! { + #[test] + fn proptest_apply_determinism_and_output_validity( + message in prop::array::uniform32(any::()), + param_values in prop::collection::vec(0u32..F::ORDER_U32, 4), + rand_values in prop::collection::vec(0u32..F::ORDER_U32, 4), + epoch in any::() + ) { + // build parameter and randomness from proptest values + let param_arr: [F; 4] = std::array::from_fn(|i| F::new(param_values[i])); + let parameter = FieldArray(param_arr); + let rand_arr: [F; 4] = std::array::from_fn(|i| F::new(rand_values[i])); + let randomness = FieldArray(rand_arr); + + // call apply twice to check determinism + let result1 = PoseidonMessageHash445::apply(¶meter, epoch, &randomness, &message); + let result2 = PoseidonMessageHash445::apply(¶meter, epoch, &randomness, &message); + + // check determinism + prop_assert_eq!(&result1, &result2); + + // check output dimension + prop_assert_eq!(result1.len(), PoseidonMessageHash445::DIMENSION); + + // check all chunks are in valid range [0, BASE-1] + for &chunk in &result1 { + prop_assert!((chunk as usize) < PoseidonMessageHash445::BASE); + } + + // check different epochs produce different results + let other_epoch = PoseidonMessageHash445::apply( + ¶meter, + epoch.wrapping_add(1), + &randomness, + &message, + ); + prop_assert_ne!(&result1[..], &other_epoch[..]); + } + + #[test] + fn proptest_encode_epoch_properties( + epoch1 in any::(), + epoch2 in any::() + ) { + // check encoding is deterministic + let result1 = encode_epoch::<4>(epoch1); + let result2 = encode_epoch::<4>(epoch1); + prop_assert_eq!(result1, result2); + + // check output has correct length + prop_assert_eq!(result1.len(), 4); + + // check different epochs produce different encodings + let other = encode_epoch::<4>(epoch2); + if epoch1 == epoch2 { + prop_assert_eq!(result1, other); + } else { + prop_assert_ne!(result1, other); + } + + // check zero epoch produces encoding with separator only (first element non-zero) + if epoch1 == 0 { + // epoch=0 should still produce non-trivial encoding due to separator + let has_nonzero = result1.iter().any(|&x| x != F::ZERO); + prop_assert!(has_nonzero); + } + } + + #[test] + fn proptest_encode_message_properties( + message1 in prop::array::uniform32(any::()), + message2 in prop::array::uniform32(any::()) + ) { + // check encoding is deterministic + let result1 = encode_message::<9>(&message1); + let result2 = encode_message::<9>(&message1); + prop_assert_eq!(result1, result2); + + // check output has correct length + prop_assert_eq!(result1.len(), 9); + + // check different messages produce different encodings + let other = encode_message::<9>(&message2); + if message1 == message2 { + prop_assert_eq!(result1, other); + } else { + prop_assert_ne!(result1, other); + } + + // check zero message produces zero encoding + let zero_msg = [0u8; 32]; + let zero_result = encode_message::<9>(&zero_msg); + prop_assert!(zero_result.iter().all(|&x| x == F::ZERO)); + } + } } diff --git a/src/symmetric/prf/shake_to_field.rs b/src/symmetric/prf/shake_to_field.rs index eef91ed..d77563b 100644 --- a/src/symmetric/prf/shake_to_field.rs +++ b/src/symmetric/prf/shake_to_field.rs @@ -125,13 +125,16 @@ where #[cfg(test)] mod tests { use super::*; + use crate::MESSAGE_LENGTH; + use proptest::prelude::*; + + const DOMAIN_LEN: usize = 4; + const RAND_LEN: usize = 4; + type PRF = ShakePRFtoF; #[test] fn test_shake_to_field_prf_key_not_all_same() { const K: usize = 10; - const DOMAIN_LEN: usize = 4; - const RAND_LEN: usize = 4; - type PRF = ShakePRFtoF; let mut rng = rand::rng(); let mut all_same_count = 0; @@ -151,4 +154,65 @@ mod tests { K ); } + + proptest! { + #[test] + fn proptest_get_domain_element_properties( + key in prop::array::uniform32(any::()), + epoch in any::(), + index1 in any::(), + index2 in any::() + ) { + // check output has correct length + let result1 = PRF::get_domain_element(&key, epoch, index1); + prop_assert_eq!(result1.len(), DOMAIN_LEN); + + // check determinism: same inputs produce same output + let result2 = PRF::get_domain_element(&key, epoch, index1); + prop_assert_eq!(result1, result2); + + // check uniqueness: different indices produce different outputs + let other = PRF::get_domain_element(&key, epoch, index2); + if index1 == index2 { + prop_assert_eq!(result1, other); + } else { + prop_assert_ne!(result1, other); + } + + // check different epochs produce different outputs + let other_epoch = PRF::get_domain_element(&key, epoch.wrapping_add(1), index1); + prop_assert_ne!(result1, other_epoch); + } + + #[test] + fn proptest_get_randomness_properties( + key in prop::array::uniform32(any::()), + epoch in any::(), + message in prop::array::uniform32(any::()), + counter1 in any::(), + counter2 in any::() + ) { + let msg: [u8; MESSAGE_LENGTH] = message; + + // check output has correct length + let result1 = PRF::get_randomness(&key, epoch, &msg, counter1); + prop_assert_eq!(result1.len(), RAND_LEN); + + // check determinism: same inputs produce same output + let result2 = PRF::get_randomness(&key, epoch, &msg, counter1); + prop_assert_eq!(result1, result2); + + // check uniqueness: different counters produce different outputs + let other = PRF::get_randomness(&key, epoch, &msg, counter2); + if counter1 == counter2 { + prop_assert_eq!(result1, other); + } else { + prop_assert_ne!(result1, other); + } + + // check different epochs produce different outputs + let other_epoch = PRF::get_randomness(&key, epoch.wrapping_add(1), &msg, counter1); + prop_assert_ne!(result1, other_epoch); + } + } } diff --git a/src/symmetric/tweak_hash/poseidon.rs b/src/symmetric/tweak_hash/poseidon.rs index 4171da6..bcd709f 100644 --- a/src/symmetric/tweak_hash/poseidon.rs +++ b/src/symmetric/tweak_hash/poseidon.rs @@ -638,9 +638,10 @@ mod tests { use num_bigint::BigUint; use rand::Rng; - use crate::symmetric::prf::shake_to_field::ShakePRFtoF; - use super::*; + use crate::symmetric::prf::shake_to_field::ShakePRFtoF; + use p3_field::PrimeField32; + use proptest::prelude::*; #[test] fn test_apply_44() { @@ -1193,4 +1194,97 @@ mod tests { ); } } + + proptest! { + #[test] + fn proptest_apply_properties( + param_values in prop::collection::vec(0u32..F::ORDER_U32, 4), + msg_values in prop::collection::vec(0u32..F::ORDER_U32, 4), + epoch in any::(), + chain_index in any::(), + pos_in_chain in any::() + ) { + // build parameter and message from proptest values + let parameter = FieldArray(std::array::from_fn::<_, 4, _>(|i| F::new(param_values[i]))); + let message = FieldArray(std::array::from_fn::<_, 4, _>(|i| F::new(msg_values[i]))); + + // create chain tweak + let tweak = PoseidonTweak44::chain_tweak(epoch, chain_index, pos_in_chain); + + // call apply twice to check determinism + let result1 = PoseidonTweak44::apply(¶meter, &tweak, &[message]); + let result2 = PoseidonTweak44::apply(¶meter, &tweak, &[message]); + + // check determinism + prop_assert_eq!(result1, result2); + + // check output has correct length + prop_assert_eq!(result1.0.len(), 4); + + // check different tweaks produce different results + let other_tweak = PoseidonTweak44::chain_tweak( + epoch.wrapping_add(1), + chain_index, + pos_in_chain, + ); + let other_result = PoseidonTweak44::apply(¶meter, &other_tweak, &[message]); + prop_assert_ne!(result1, other_result); + } + + #[test] + fn proptest_chain_tweak_encoding_properties( + epoch1 in any::(), + epoch2 in any::(), + chain_index in any::(), + pos_in_chain in any::() + ) { + // check encoding is deterministic + let tweak1 = PoseidonTweak::ChainTweak { epoch: epoch1, chain_index, pos_in_chain }; + let result1 = tweak1.to_field_elements::<2>(); + let result2 = tweak1.to_field_elements::<2>(); + prop_assert_eq!(result1, result2); + + // check output has correct length + prop_assert_eq!(result1.len(), 2); + + // check different epochs produce different encodings + let tweak2 = PoseidonTweak::ChainTweak { epoch: epoch2, chain_index, pos_in_chain }; + let other = tweak2.to_field_elements::<2>(); + if epoch1 == epoch2 { + prop_assert_eq!(result1, other); + } else { + prop_assert_ne!(result1, other); + } + + // check chain tweaks differ from tree tweaks (domain separation) + let tree_tweak = PoseidonTweak::TreeTweak { level: 0, pos_in_level: epoch1 }; + let tree_result = tree_tweak.to_field_elements::<2>(); + prop_assert_ne!(result1, tree_result); + } + + #[test] + fn proptest_tree_tweak_encoding_properties( + level1 in any::(), + level2 in any::(), + pos_in_level in any::() + ) { + // check encoding is deterministic + let tweak1 = PoseidonTweak::TreeTweak { level: level1, pos_in_level }; + let result1 = tweak1.to_field_elements::<2>(); + let result2 = tweak1.to_field_elements::<2>(); + prop_assert_eq!(result1, result2); + + // check output has correct length + prop_assert_eq!(result1.len(), 2); + + // check different levels produce different encodings + let tweak2 = PoseidonTweak::TreeTweak { level: level2, pos_in_level }; + let other = tweak2.to_field_elements::<2>(); + if level1 == level2 { + prop_assert_eq!(result1, other); + } else { + prop_assert_ne!(result1, other); + } + } + } }