Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
154 changes: 154 additions & 0 deletions src/inc_encoding/target_sum.rs
Original file line number Diff line number Diff line change
Expand Up @@ -86,3 +86,157 @@ impl<MH: MessageHash, const TARGET_SUM: usize> IncomparableEncoding
MH::internal_consistency_check();
}
}

#[cfg(test)]
mod tests {
use super::*;
use crate::F;
use crate::array::FieldArray;
use crate::symmetric::message_hash::MessageHash;
use crate::symmetric::message_hash::poseidon::PoseidonMessageHash445;
use p3_field::PrimeField32;
use proptest::prelude::*;

const TEST_TARGET_SUM: usize = 115;
type TestTargetSumEncoding = TargetSumEncoding<PoseidonMessageHash445, TEST_TARGET_SUM>;

#[test]
fn test_internal_consistency() {
TestTargetSumEncoding::internal_consistency_check();
}

#[test]
fn test_successful_encoding_fixed_message() {
// keep message fixed and only resample randomness
// this mirrors the actual signature scheme behavior
let mut rng = rand::rng();
let parameter: FieldArray<4> = FieldArray(rng.random());
let message: [u8; 32] = rng.random();
let epoch = 0u32;

// retry with different randomness until encoding succeeds
for _ in 0..1_000 {
let randomness = TestTargetSumEncoding::rand(&mut rng);

if let Ok(chunks) =
TestTargetSumEncoding::encode(&parameter, &message, &randomness, epoch)
{
// check output has correct dimension
assert_eq!(chunks.len(), TestTargetSumEncoding::DIMENSION);

// check all chunks are in valid range [0, BASE-1]
for &chunk in &chunks {
assert!((chunk as usize) < TestTargetSumEncoding::BASE);
}

// check sum equals target
let sum: usize = chunks.iter().map(|&x| x as usize).sum();
assert_eq!(sum, TEST_TARGET_SUM);

// check determinism: encoding again with same inputs produces same result
let result2 =
TestTargetSumEncoding::encode(&parameter, &message, &randomness, epoch);
assert_eq!(chunks, result2.unwrap());

return;
}
}

panic!("failed to find successful encoding after 1000 attempts");
}

#[test]
fn test_successful_encoding_random_inputs() {
// retry with all random inputs until encoding succeeds
let mut rng = rand::rng();
let epoch = 0u32;

for _ in 0..1_000 {
let parameter: FieldArray<4> = FieldArray(rng.random());
let message: [u8; 32] = rng.random();
let randomness = TestTargetSumEncoding::rand(&mut rng);

if let Ok(chunks) =
TestTargetSumEncoding::encode(&parameter, &message, &randomness, epoch)
{
// check output has correct dimension
assert_eq!(chunks.len(), TestTargetSumEncoding::DIMENSION);

// check all chunks are in valid range [0, BASE-1]
for &chunk in &chunks {
assert!((chunk as usize) < TestTargetSumEncoding::BASE);
}

// check sum equals target
let sum: usize = chunks.iter().map(|&x| x as usize).sum();
assert_eq!(sum, TEST_TARGET_SUM);

// check determinism: encoding again with same inputs produces same result
let result2 =
TestTargetSumEncoding::encode(&parameter, &message, &randomness, epoch);
assert_eq!(chunks, result2.unwrap());

return;
}
}

panic!("failed to find successful encoding after 1000 attempts");
}

proptest! {
#[test]
fn proptest_encoding_determinism_and_error_reporting(
message in prop::array::uniform32(any::<u8>()),
randomness_values in prop::collection::vec(0u32..F::ORDER_U32, 4),
parameter_values in prop::collection::vec(0u32..F::ORDER_U32, 4),
epoch in any::<u32>()
) {
// build randomness and parameter from proptest values
let randomness_arr: [F; 4] = std::array::from_fn(|i| F::new(randomness_values[i]));
let randomness = FieldArray(randomness_arr);
let parameter_arr: [F; 4] = std::array::from_fn(|i| F::new(parameter_values[i]));
let parameter = FieldArray(parameter_arr);

// compute expected sum from underlying message hash
let hash_chunks = PoseidonMessageHash445::apply(&parameter, epoch, &randomness, &message);
let hash_sum: usize = hash_chunks.iter().map(|&x| x as usize).sum();

// call encode twice to check determinism
let result1 = TestTargetSumEncoding::encode(&parameter, &message, &randomness, epoch);
let result2 = TestTargetSumEncoding::encode(&parameter, &message, &randomness, epoch);

// check determinism: both calls produce same result
match (&result1, &result2) {
(Ok(c1), Ok(c2)) => prop_assert_eq!(c1, c2),
(Err(TargetSumError::Mismatch { expected: e1, actual: a1 }),
Err(TargetSumError::Mismatch { expected: e2, actual: a2 })) => {
prop_assert_eq!(e1, e2);
prop_assert_eq!(a1, a2);
}
_ => prop_assert!(false, "determinism violated"),
}

// check properties based on success/failure
match result1 {
Err(TargetSumError::Mismatch { expected, actual }) => {
// check error reports correct values
prop_assert_eq!(expected, TEST_TARGET_SUM);
prop_assert_eq!(actual, hash_sum);
}
Ok(chunks) => {
// check output dimension
prop_assert_eq!(chunks.len(), TestTargetSumEncoding::DIMENSION);

// check all chunks in valid range
for &chunk in &chunks {
prop_assert!((chunk as usize) < TestTargetSumEncoding::BASE);
}

// check sum equals target
let sum: usize = chunks.iter().map(|&x| x as usize).sum();
prop_assert_eq!(sum, TEST_TARGET_SUM);
}
}
}
}
}
71 changes: 71 additions & 0 deletions src/signature/generalized_xmss.rs
Original file line number Diff line number Diff line change
Expand Up @@ -994,6 +994,10 @@ mod tests {

use super::*;

use crate::array::FieldArray;
use p3_field::PrimeField32;
use proptest::prelude::*;

use crate::{F, symmetric::tweak_hash::poseidon::PoseidonTweakHash};
use p3_field::RawDataSerializable;
use rand::rng;
Expand Down Expand Up @@ -1535,4 +1539,71 @@ mod tests {
// Verify signature from decoded key validates
assert!(Sig::verify(&pk, epoch + 1, &message, &sig2));
}

proptest! {
#[test]
fn proptest_expand_activation_time_invariants(
desired_start in 0usize..256,
desired_duration in 1usize..256
) {
const LOG_LIFETIME: usize = 8;
const C: usize = 1 << (LOG_LIFETIME / 2);
const LIFETIME: usize = 1 << LOG_LIFETIME;

let desired_end = (desired_start + desired_duration).min(LIFETIME);

let (start, end) = expand_activation_time::<LOG_LIFETIME>(desired_start, desired_duration);

let actual_start = start * C;
let actual_end = end * C;

// check minimum duration of 2 bottom trees (each tree has C leaves)
prop_assert!(actual_end - actual_start >= 2 * C);

// check result fits within lifetime
prop_assert!(actual_end <= LIFETIME);

// check result contains the desired interval
prop_assert!(actual_start <= desired_start);
prop_assert!(actual_end >= desired_end);

// check determinism by calling twice
let (start2, end2) = expand_activation_time::<LOG_LIFETIME>(desired_start, desired_duration);
prop_assert_eq!((start, end), (start2, end2));
}

#[test]
fn proptest_ssz_public_key_roundtrip_and_determinism(
root_values in prop::collection::vec(0u32..F::ORDER_U32, 7),
param_values in prop::collection::vec(0u32..F::ORDER_U32, 5)
) {
// build public key from random field element values
let root_arr: [F; 7] = std::array::from_fn(|i| F::new(root_values[i]));
let param_arr: [F; 5] = std::array::from_fn(|i| F::new(param_values[i]));

let original = GeneralizedXMSSPublicKey::<TestTH> {
root: FieldArray(root_arr),
parameter: FieldArray(param_arr),
};

// encode to SSZ bytes
let encoded1 = original.as_ssz_bytes();
let encoded2 = original.as_ssz_bytes();

// check encoding is deterministic
prop_assert_eq!(&encoded1, &encoded2);

// check size matches expected (7 + 5 field elements * 4 bytes)
let expected_size = 12 * F::NUM_BYTES;
prop_assert_eq!(encoded1.len(), expected_size);
prop_assert_eq!(original.ssz_bytes_len(), expected_size);

// decode and check roundtrip preserves data
let decoded = GeneralizedXMSSPublicKey::<TestTH>::from_ssz_bytes(&encoded1)
.expect("valid SSZ bytes should decode");

prop_assert_eq!(original.root, decoded.root);
prop_assert_eq!(original.parameter, decoded.parameter);
}
}
}
98 changes: 98 additions & 0 deletions src/symmetric/message_hash/poseidon.rs
Original file line number Diff line number Diff line change
Expand Up @@ -232,6 +232,8 @@ pub type PoseidonMessageHashW1 = PoseidonMessageHash<5, 5, 5, 163, 2, 2, 9>;
mod tests {
use super::*;
use num_traits::Zero;
use p3_field::PrimeField32;
use proptest::prelude::*;
use rand::Rng;
use std::collections::HashMap;

Expand Down Expand Up @@ -610,4 +612,100 @@ mod tests {
"Reconstructed bigint from chunks does not match bigint from field elements"
);
}

proptest! {
#[test]
fn proptest_apply_determinism_and_output_validity(
message in prop::array::uniform32(any::<u8>()),
param_values in prop::collection::vec(0u32..F::ORDER_U32, 4),
rand_values in prop::collection::vec(0u32..F::ORDER_U32, 4),
epoch in any::<u32>()
) {
// build parameter and randomness from proptest values
let param_arr: [F; 4] = std::array::from_fn(|i| F::new(param_values[i]));
let parameter = FieldArray(param_arr);
let rand_arr: [F; 4] = std::array::from_fn(|i| F::new(rand_values[i]));
let randomness = FieldArray(rand_arr);

// call apply twice to check determinism
let result1 = PoseidonMessageHash445::apply(&parameter, epoch, &randomness, &message);
let result2 = PoseidonMessageHash445::apply(&parameter, epoch, &randomness, &message);

// check determinism
prop_assert_eq!(&result1, &result2);

// check output dimension
prop_assert_eq!(result1.len(), PoseidonMessageHash445::DIMENSION);

// check all chunks are in valid range [0, BASE-1]
for &chunk in &result1 {
prop_assert!((chunk as usize) < PoseidonMessageHash445::BASE);
}

// check different epochs produce different results
let other_epoch = PoseidonMessageHash445::apply(
&parameter,
epoch.wrapping_add(1),
&randomness,
&message,
);
prop_assert_ne!(&result1[..], &other_epoch[..]);
}

#[test]
fn proptest_encode_epoch_properties(
epoch1 in any::<u32>(),
epoch2 in any::<u32>()
) {
// check encoding is deterministic
let result1 = encode_epoch::<4>(epoch1);
let result2 = encode_epoch::<4>(epoch1);
prop_assert_eq!(result1, result2);

// check output has correct length
prop_assert_eq!(result1.len(), 4);

// check different epochs produce different encodings
let other = encode_epoch::<4>(epoch2);
if epoch1 == epoch2 {
prop_assert_eq!(result1, other);
} else {
prop_assert_ne!(result1, other);
}

// check zero epoch produces encoding with separator only (first element non-zero)
if epoch1 == 0 {
// epoch=0 should still produce non-trivial encoding due to separator
let has_nonzero = result1.iter().any(|&x| x != F::ZERO);
prop_assert!(has_nonzero);
}
}

#[test]
fn proptest_encode_message_properties(
message1 in prop::array::uniform32(any::<u8>()),
message2 in prop::array::uniform32(any::<u8>())
) {
// check encoding is deterministic
let result1 = encode_message::<9>(&message1);
let result2 = encode_message::<9>(&message1);
prop_assert_eq!(result1, result2);

// check output has correct length
prop_assert_eq!(result1.len(), 9);

// check different messages produce different encodings
let other = encode_message::<9>(&message2);
if message1 == message2 {
prop_assert_eq!(result1, other);
} else {
prop_assert_ne!(result1, other);
}

// check zero message produces zero encoding
let zero_msg = [0u8; 32];
let zero_result = encode_message::<9>(&zero_msg);
prop_assert!(zero_result.iter().all(|&x| x == F::ZERO));
}
}
}
Loading