diff --git a/src/lib.rs b/src/lib.rs index ce6be6f..2549176 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -20,7 +20,7 @@ pub(crate) mod inc_encoding; pub mod serialization; pub mod signature; pub(crate) mod simd_utils; -pub(crate) mod symmetric; +pub mod symmetric; // Cached Poseidon2 permutations. // diff --git a/src/signature/generalized_xmss.rs b/src/signature/generalized_xmss.rs index 983e57f..f7fd572 100644 --- a/src/signature/generalized_xmss.rs +++ b/src/signature/generalized_xmss.rs @@ -43,9 +43,9 @@ pub struct GeneralizedXMSSSignatureScheme< #[derive(Serialize, Deserialize)] #[serde(bound = "")] pub struct GeneralizedXMSSSignature { - path: HashTreeOpening, - rho: IE::Randomness, - hashes: Vec, + pub path: HashTreeOpening, + pub rho: IE::Randomness, + pub hashes: Vec, } impl Encode for GeneralizedXMSSSignature { @@ -176,8 +176,8 @@ impl Decode for GeneralizedXMSSSign /// It contains a Merkle root and a parameter for the tweakable hash #[derive(Serialize, Deserialize)] pub struct GeneralizedXMSSPublicKey { - root: TH::Domain, - parameter: TH::Parameter, + pub root: TH::Domain, + pub parameter: TH::Parameter, } /// Secret key for GeneralizedXMSSSignatureScheme @@ -1003,7 +1003,7 @@ mod tests { use rand::rng; use ssz::{Decode, Encode}; - type TestTH = PoseidonTweakHash<5, 7, 2, 9, 155>; + type TestTH = PoseidonTweakHash<5, 7, 9, 155>; #[test] pub fn test_target_sum_poseidon() { @@ -1073,7 +1073,7 @@ mod tests { pub fn test_large_base_poseidon() { // Note: do not use these parameters, they are just for testing type PRF = ShakePRFtoF<4, 8>; - type TH = PoseidonTweakHash<4, 4, 2, 8, 32>; + type TH = PoseidonTweakHash<4, 4, 8, 32>; type MH = PoseidonMessageHash<4, 8, 8, 32, 256, 2, 9>; const TARGET_SUM: usize = 1 << 12; type IE = TargetSumEncoding; @@ -1090,7 +1090,7 @@ mod tests { pub fn test_large_dimension_poseidon() { // Note: do not use these parameters, they are just for testing type PRF = ShakePRFtoF<8, 8>; - type TH = PoseidonTweakHash<4, 8, 2, 8, 256>; + type TH = PoseidonTweakHash<4, 8, 8, 256>; type MH = PoseidonMessageHash<4, 8, 8, 256, 2, 2, 9>; const TARGET_SUM: usize = 128; type IE = TargetSumEncoding; diff --git a/src/signature/generalized_xmss/instantiations_poseidon.rs b/src/signature/generalized_xmss/instantiations_poseidon.rs index 5858864..905b182 100644 --- a/src/signature/generalized_xmss/instantiations_poseidon.rs +++ b/src/signature/generalized_xmss/instantiations_poseidon.rs @@ -33,8 +33,7 @@ pub mod lifetime_2_to_the_18 { TWEAK_LEN_FE, MSG_LEN_FE, >; - type THw1 = - PoseidonTweakHash; + type THw1 = PoseidonTweakHash; type PRFw1 = ShakePRFtoF; type IEw1 = TargetSumEncoding; /// Instantiation with Lifetime 2^18, Target sum encoding, chunk size w = 1, @@ -59,8 +58,7 @@ pub mod lifetime_2_to_the_18 { TWEAK_LEN_FE, MSG_LEN_FE, >; - type THw2 = - PoseidonTweakHash; + type THw2 = PoseidonTweakHash; type PRFw2 = ShakePRFtoF; type IEw2 = TargetSumEncoding; /// Instantiation with Lifetime 2^18, Target sum encoding, chunk size w = 2, @@ -85,8 +83,7 @@ pub mod lifetime_2_to_the_18 { TWEAK_LEN_FE, MSG_LEN_FE, >; - type THw4 = - PoseidonTweakHash; + type THw4 = PoseidonTweakHash; type PRFw4 = ShakePRFtoF; type IEw4 = TargetSumEncoding; /// Instantiation with Lifetime 2^18, Target sum encoding, chunk size w = 4, @@ -111,8 +108,7 @@ pub mod lifetime_2_to_the_18 { TWEAK_LEN_FE, MSG_LEN_FE, >; - type THw8 = - PoseidonTweakHash; + type THw8 = PoseidonTweakHash; type PRFw8 = ShakePRFtoF; type IEw8 = TargetSumEncoding; /// Instantiation with Lifetime 2^18, Target sum encoding, chunk size w = 8, @@ -260,8 +256,7 @@ pub mod lifetime_2_to_the_20 { TWEAK_LEN_FE, MSG_LEN_FE, >; - type THw1 = - PoseidonTweakHash; + type THw1 = PoseidonTweakHash; type PRFw1 = ShakePRFtoF; type IEw1 = TargetSumEncoding; /// Instantiation with Lifetime 2^20, Target sum encoding, chunk size w = 1, @@ -286,8 +281,7 @@ pub mod lifetime_2_to_the_20 { TWEAK_LEN_FE, MSG_LEN_FE, >; - type THw2 = - PoseidonTweakHash; + type THw2 = PoseidonTweakHash; type PRFw2 = ShakePRFtoF; type IEw2 = TargetSumEncoding; /// Instantiation with Lifetime 2^20, Target sum encoding, chunk size w = 2, @@ -312,8 +306,7 @@ pub mod lifetime_2_to_the_20 { TWEAK_LEN_FE, MSG_LEN_FE, >; - type THw4 = - PoseidonTweakHash; + type THw4 = PoseidonTweakHash; type PRFw4 = ShakePRFtoF; type IEw4 = TargetSumEncoding; /// Instantiation with Lifetime 2^20, Target sum encoding, chunk size w = 4, @@ -339,8 +332,7 @@ pub mod lifetime_2_to_the_20 { TWEAK_LEN_FE, MSG_LEN_FE, >; - type THw8 = - PoseidonTweakHash; + type THw8 = PoseidonTweakHash; type PRFw8 = ShakePRFtoF; type IEw8 = TargetSumEncoding; /// Instantiation with Lifetime 2^20, Target sum encoding, chunk size w = 8, diff --git a/src/signature/generalized_xmss/instantiations_poseidon_top_level.rs b/src/signature/generalized_xmss/instantiations_poseidon_top_level.rs index bea2239..4065029 100644 --- a/src/signature/generalized_xmss/instantiations_poseidon_top_level.rs +++ b/src/signature/generalized_xmss/instantiations_poseidon_top_level.rs @@ -39,7 +39,7 @@ pub mod lifetime_2_to_the_18 { PARAMETER_LEN, RAND_LEN_FE, >; - type TH = PoseidonTweakHash; + type TH = PoseidonTweakHash; type PRF = ShakePRFtoF; type IE = TargetSumEncoding; @@ -86,7 +86,9 @@ pub mod lifetime_2_to_the_32 { use crate::{ inc_encoding::target_sum::TargetSumEncoding, - signature::generalized_xmss::GeneralizedXMSSSignatureScheme, + signature::generalized_xmss::{ + GeneralizedXMSSPublicKey, GeneralizedXMSSSignature, GeneralizedXMSSSignatureScheme, + }, symmetric::{ message_hash::top_level_poseidon::TopLevelPoseidonMessageHash, prf::shake_to_field::ShakePRFtoF, tweak_hash::poseidon::PoseidonTweakHash, @@ -103,7 +105,7 @@ pub mod lifetime_2_to_the_32 { const PARAMETER_LEN: usize = 5; const TWEAK_LEN_FE: usize = 2; const MSG_LEN_FE: usize = 9; - const RAND_LEN_FE: usize = 7; + pub const RAND_LEN_FE: usize = 7; const HASH_LEN_FE: usize = 8; const CAPACITY: usize = 9; @@ -112,7 +114,7 @@ pub mod lifetime_2_to_the_32 { const POS_INVOCATIONS: usize = 1; const POS_OUTPUT_LEN_FE: usize = POS_OUTPUT_LEN_PER_INV_FE * POS_INVOCATIONS; - type MH = TopLevelPoseidonMessageHash< + pub type MH = TopLevelPoseidonMessageHash< POS_OUTPUT_LEN_PER_INV_FE, POS_INVOCATIONS, POS_OUTPUT_LEN_FE, @@ -124,12 +126,14 @@ pub mod lifetime_2_to_the_32 { PARAMETER_LEN, RAND_LEN_FE, >; - type TH = PoseidonTweakHash; + type TH = PoseidonTweakHash; type PRF = ShakePRFtoF; type IE = TargetSumEncoding; pub type SIGTopLevelTargetSumLifetime32Dim64Base8 = GeneralizedXMSSSignatureScheme; + pub type PubKeyTopLevelTargetSumLifetime32Dim64Base8 = GeneralizedXMSSPublicKey; + pub type SigTopLevelTargetSumLifetime32Dim64Base8 = GeneralizedXMSSSignature; #[cfg(test)] mod test { @@ -205,7 +209,7 @@ pub mod lifetime_2_to_the_32 { PARAMETER_LEN, RAND_LEN_FE, >; - type TH = PoseidonTweakHash; + type TH = PoseidonTweakHash; type PRF = ShakePRFtoF; type IE = TargetSumEncoding; @@ -285,7 +289,7 @@ pub mod lifetime_2_to_the_32 { PARAMETER_LEN, RAND_LEN_FE, >; - type TH = PoseidonTweakHash; + type TH = PoseidonTweakHash; type PRF = ShakePRFtoF; type IE = TargetSumEncoding; @@ -368,7 +372,7 @@ pub mod lifetime_2_to_the_8 { PARAMETER_LEN, RAND_LEN_FE, >; - type TH = PoseidonTweakHash; + type TH = PoseidonTweakHash; type PRF = ShakePRFtoF; diff --git a/src/symmetric/tweak_hash/poseidon.rs b/src/symmetric/tweak_hash/poseidon.rs index bcd709f..7ce284d 100644 --- a/src/symmetric/tweak_hash/poseidon.rs +++ b/src/symmetric/tweak_hash/poseidon.rs @@ -21,6 +21,8 @@ const DOMAIN_PARAMETERS_LENGTH: usize = 4; const CHAIN_COMPRESSION_WIDTH: usize = 16; /// The state width for merging two hashes in a tree or for the sponge construction. const MERGE_COMPRESSION_WIDTH: usize = 24; +/// Number of field elements used to represent the tweak. +pub const TWEAK_LEN_FE: usize = 3; /// Enum to implement tweaks. pub enum PoseidonTweak { @@ -36,35 +38,41 @@ pub enum PoseidonTweak { } impl PoseidonTweak { - fn to_field_elements(&self) -> [F; TWEAK_LEN] { - // We first represent the entire tweak as one big integer - let mut acc = match self { + fn to_field_elements(&self) -> [F; TWEAK_LEN_FE] { + const _: () = assert!( + F::ORDER_U64 > 1 << 30, + "we need to store 30 bits in one field element" + ); + + match self { Self::TreeTweak { level, pos_in_level, } => { - ((*level as u128) << 40) - | ((*pos_in_level as u128) << 8) - | (TWEAK_SEPARATOR_FOR_TREE_HASH as u128) + // split pos_in_level (32 bits) into (30 bits, 2 bits) + [ + F::from_u32(pos_in_level & ((1 << 30) - 1)), + F::from_u32(pos_in_level >> 30), + F::from_u32(((*level as u32) << 8) | TWEAK_SEPARATOR_FOR_TREE_HASH as u32), + ] } Self::ChainTweak { epoch, chain_index, pos_in_chain, } => { - ((*epoch as u128) << 24) - | ((*chain_index as u128) << 16) - | ((*pos_in_chain as u128) << 8) - | (TWEAK_SEPARATOR_FOR_CHAIN_HASH as u128) + // split epoch (32 bits) into (30 bits, 2 bits) + [ + F::from_u32(epoch & ((1 << 30) - 1)), + F::from_u32(epoch >> 30), + F::from_u32( + ((*chain_index as u32) << 16) + | ((*pos_in_chain as u32) << 8) + | (TWEAK_SEPARATOR_FOR_CHAIN_HASH as u32), + ), + ] } - }; - - // Now we interpret this integer in base-p to get field elements - std::array::from_fn(|_| { - let digit = (acc % F::ORDER_U64 as u128) as u64; - acc /= F::ORDER_U64 as u128; - F::from_u64(digit) - }) + } } } @@ -251,7 +259,6 @@ where pub struct PoseidonTweakHash< const PARAMETER_LEN: usize, const HASH_LEN: usize, - const TWEAK_LEN: usize, const CAPACITY: usize, const NUM_CHUNKS: usize, >; @@ -259,10 +266,9 @@ pub struct PoseidonTweakHash< impl< const PARAMETER_LEN: usize, const HASH_LEN: usize, - const TWEAK_LEN: usize, const CAPACITY: usize, const NUM_CHUNKS: usize, -> TweakableHash for PoseidonTweakHash +> TweakableHash for PoseidonTweakHash { type Parameter = FieldArray; @@ -303,7 +309,7 @@ impl< // (2) hashing two siblings in the tree. We use compression mode. // (3) hashing a long vector of chain ends. We use sponge mode. - let tweak_fe = tweak.to_field_elements::(); + let tweak_fe = tweak.to_field_elements(); match message { [single] => { @@ -353,7 +359,7 @@ impl< let lengths: [u32; DOMAIN_PARAMETERS_LENGTH] = [ PARAMETER_LEN as u32, - TWEAK_LEN as u32, + TWEAK_LEN_FE as u32, NUM_CHUNKS as u32, HASH_LEN as u32, ]; @@ -420,7 +426,7 @@ impl< // This ensures different use cases produce different outputs. let lengths = [ PARAMETER_LEN as u32, - TWEAK_LEN as u32, + TWEAK_LEN_FE as u32, NUM_CHUNKS as u32, HASH_LEN as u32, ]; @@ -477,10 +483,10 @@ impl< // Generate tweaks for all epochs in this SIMD batch. // Each lane gets a tweak specific to its epoch. - let packed_tweak = array::from_fn::<_, TWEAK_LEN, _>(|t_idx| { + let packed_tweak = array::from_fn::<_, TWEAK_LEN_FE, _>(|t_idx| { PackedF::from_fn(|lane| { Self::chain_tweak(epoch_chunk[lane], chain_index as u8, pos) - .to_field_elements::()[t_idx] + .to_field_elements()[t_idx] }) }); @@ -495,9 +501,9 @@ impl< current_pos += PARAMETER_LEN; // Copy tweak into the input buffer. - packed_input[current_pos..current_pos + TWEAK_LEN] + packed_input[current_pos..current_pos + TWEAK_LEN_FE] .copy_from_slice(&packed_tweak); - current_pos += TWEAK_LEN; + current_pos += TWEAK_LEN_FE; // Copy current chain value into the input buffer. packed_input[current_pos..current_pos + HASH_LEN] @@ -522,9 +528,9 @@ impl< // Generate tree tweaks for all epochs. // Level 0 indicates this is a bottom-layer leaf in the tree. - let packed_tree_tweak = array::from_fn::<_, TWEAK_LEN, _>(|t_idx| { + let packed_tree_tweak = array::from_fn::<_, TWEAK_LEN_FE, _>(|t_idx| { PackedF::from_fn(|lane| { - Self::tree_tweak(0, epoch_chunk[lane]).to_field_elements::() + Self::tree_tweak(0, epoch_chunk[lane]).to_field_elements() [t_idx] }) }); @@ -594,11 +600,11 @@ impl< "Poseidon Tweak Chain Hash: Capacity must be less than 24" ); assert!( - PARAMETER_LEN + TWEAK_LEN + HASH_LEN <= 16, + PARAMETER_LEN + TWEAK_LEN_FE + HASH_LEN <= 16, "Poseidon Tweak Chain Hash: Input lengths too large for Poseidon instance" ); assert!( - PARAMETER_LEN + TWEAK_LEN + 2 * HASH_LEN <= 24, + PARAMETER_LEN + TWEAK_LEN_FE + 2 * HASH_LEN <= 24, "Poseidon Tweak Tree Hash: Input lengths too large for Poseidon instance" ); @@ -611,7 +617,7 @@ impl< let bits_for_tree_tweak = f64::from(32 + 8_u32); let bits_for_chain_tweak = f64::from(32 + 8 + 8 + 8_u32); - let tweak_fe_bits = bits_per_fe * f64::from(TWEAK_LEN as u32); + let tweak_fe_bits = bits_per_fe * f64::from(TWEAK_LEN_FE as u32); assert!( tweak_fe_bits >= bits_for_tree_tweak, "Poseidon Tweak Hash: not enough field elements to encode the tree tweak" @@ -625,17 +631,16 @@ impl< // Example instantiations #[cfg(test)] -pub type PoseidonTweak44 = PoseidonTweakHash<4, 4, 3, 9, 128>; +pub type PoseidonTweak44 = PoseidonTweakHash<4, 4, 9, 128>; #[cfg(test)] -pub type PoseidonTweak37 = PoseidonTweakHash<3, 7, 3, 9, 128>; +pub type PoseidonTweak37 = PoseidonTweakHash<3, 7, 9, 128>; #[cfg(test)] -pub type PoseidonTweakW1L5 = PoseidonTweakHash<5, 7, 2, 9, 163>; +pub type PoseidonTweakW1L5 = PoseidonTweakHash<5, 7, 9, 163>; #[cfg(test)] mod tests { use std::collections::HashMap; - use num_bigint::BigUint; use rand::Rng; use super::*; @@ -745,19 +750,12 @@ mod tests { // Tweak let level = 1u8; let pos_in_level = 2u32; - let sep = TWEAK_SEPARATOR_FOR_TREE_HASH as u64; - - // Compute tweak_bigint - let tweak_bigint: BigUint = - (BigUint::from(level) << 40) + (BigUint::from(pos_in_level) << 8) + sep; - - // Use the field modulus - let p = BigUint::from(F::ORDER_U64); // Extract field elements in base-p let expected = [ - F::from_u128((&tweak_bigint % &p).try_into().unwrap()), - F::from_u128(((&tweak_bigint / &p) % &p).try_into().unwrap()), + F::from_u32(2), // pos_in_level & ((1 << 30) - 1) + F::from_u32(0), // pos_in_level >> 30 + F::from_u32((1u32 << 8) | TWEAK_SEPARATOR_FOR_TREE_HASH as u32), // (level << 8) | sep ]; // Check actual output @@ -765,7 +763,7 @@ mod tests { level, pos_in_level, }; - let computed = tweak.to_field_elements::<2>(); + let computed = tweak.to_field_elements(); assert_eq!(computed, expected); } @@ -775,21 +773,12 @@ mod tests { let epoch = 1u32; let chain_index = 2u8; let pos_in_chain = 3u8; - let sep = TWEAK_SEPARATOR_FOR_CHAIN_HASH as u64; - - // Compute tweak_bigint = (epoch << 24) + (chain_index << 16) + (pos_in_chain << 8) + sep - let tweak_bigint: BigUint = (BigUint::from(epoch) << 24) - + (BigUint::from(chain_index) << 16) - + (BigUint::from(pos_in_chain) << 8) - + sep; - - // Use the field modulus - let p = BigUint::from(F::ORDER_U64); // Extract field elements in base-p let expected = [ - F::from_u128((&tweak_bigint % &p).try_into().unwrap()), - F::from_u128(((&tweak_bigint / &p) % &p).try_into().unwrap()), + F::from_u32(1), // epoch & ((1 << 30) - 1) + F::from_u32(0), // epoch >> 30 + F::from_u32((2u32 << 16) | (3u32 << 8) | TWEAK_SEPARATOR_FOR_CHAIN_HASH as u32), // (chain_index << 16) | (pos_in_chain << 8) | sep ]; // Check actual output @@ -798,7 +787,7 @@ mod tests { chain_index, pos_in_chain, }; - let computed = tweak.to_field_elements::<2>(); + let computed = tweak.to_field_elements(); assert_eq!(computed, expected); } @@ -806,22 +795,18 @@ mod tests { fn test_tree_tweak_field_elements_max_values() { let level = u8::MAX; let pos_in_level = u32::MAX; - let sep = TWEAK_SEPARATOR_FOR_TREE_HASH as u64; - - let tweak_bigint: BigUint = - (BigUint::from(level) << 40) + (BigUint::from(pos_in_level) << 8) + sep; - let p = BigUint::from(F::ORDER_U64); let expected = [ - F::from_u128((&tweak_bigint % &p).try_into().unwrap()), - F::from_u128(((&tweak_bigint / &p) % &p).try_into().unwrap()), + F::from_u32((1 << 30) - 1), // pos_in_level & ((1 << 30) - 1) + F::from_u32(3), // pos_in_level >> 30 + F::from_u32((255u32 << 8) | TWEAK_SEPARATOR_FOR_TREE_HASH as u32), // (level << 8) | sep ]; let tweak = PoseidonTweak::TreeTweak { level, pos_in_level, }; - let computed = tweak.to_field_elements::<2>(); + let computed = tweak.to_field_elements(); assert_eq!(computed, expected); } @@ -830,17 +815,10 @@ mod tests { let epoch = u32::MAX; let chain_index = u8::MAX; let pos_in_chain = u8::MAX; - let sep = TWEAK_SEPARATOR_FOR_CHAIN_HASH as u64; - - let tweak_bigint: BigUint = (BigUint::from(epoch) << 24) - + (BigUint::from(chain_index) << 16) - + (BigUint::from(pos_in_chain) << 8) - + sep; - - let p = BigUint::from(F::ORDER_U64); let expected = [ - F::from_u128((&tweak_bigint % &p).try_into().unwrap()), - F::from_u128(((&tweak_bigint / &p) % &p).try_into().unwrap()), + F::from_u32((1 << 30) - 1), // epoch & ((1 << 30) - 1) + F::from_u32(3), // epoch >> 30 + F::from_u32((255u32 << 16) | (255u32 << 8) | TWEAK_SEPARATOR_FOR_CHAIN_HASH as u32), // (chain_index << 16) | (pos_in_chain << 8) | sep ]; let tweak = PoseidonTweak::ChainTweak { @@ -848,7 +826,7 @@ mod tests { chain_index, pos_in_chain, }; - let computed = tweak.to_field_elements::<2>(); + let computed = tweak.to_field_elements(); assert_eq!(computed, expected); } @@ -868,7 +846,7 @@ mod tests { level, pos_in_level, } - .to_field_elements::<2>(); + .to_field_elements(); if let Some((prev_level, prev_pos_in_level)) = map.insert(tweak_encoding, (level, pos_in_level)) @@ -895,7 +873,7 @@ mod tests { level, pos_in_level, } - .to_field_elements::<2>(); + .to_field_elements(); if let Some(prev_pos_in_level) = map.insert(tweak_encoding, pos_in_level) { assert_eq!( @@ -915,7 +893,7 @@ mod tests { level, pos_in_level, } - .to_field_elements::<2>(); + .to_field_elements(); if let Some(prev_level) = map.insert(tweak_encoding, level) { assert_eq!( @@ -948,7 +926,7 @@ mod tests { chain_index, pos_in_chain, } - .to_field_elements::<2>(); + .to_field_elements(); if let Some(prev_input) = map.insert(tweak_encoding, input) { assert_eq!( @@ -972,7 +950,7 @@ mod tests { chain_index, pos_in_chain, } - .to_field_elements::<2>(); + .to_field_elements(); if let Some(prev_input) = map.insert(tweak_encoding, input) { assert_eq!( @@ -996,7 +974,7 @@ mod tests { chain_index, pos_in_chain, } - .to_field_elements::<2>(); + .to_field_elements(); if let Some(prev_input) = map.insert(tweak_encoding, input) { assert_eq!( @@ -1020,7 +998,7 @@ mod tests { chain_index, pos_in_chain, } - .to_field_elements::<2>(); + .to_field_elements(); if let Some(prev_input) = map.insert(tweak_encoding, input) { assert_eq!( @@ -1037,7 +1015,6 @@ mod tests { PRF: Pseudorandom, const PARAMETER_LEN: usize, const HASH_LEN: usize, - const TWEAK_LEN: usize, const CAPACITY: usize, const NUM_CHUNKS: usize, >( @@ -1115,7 +1092,7 @@ mod tests { ); // Compute using naive/scalar implementation - let naive_result = compute_tree_leaves_naive::( + let naive_result = compute_tree_leaves_naive::( &prf_key, ¶meter, &epochs, @@ -1170,7 +1147,7 @@ mod tests { ); // Compute using naive/scalar implementation - let naive_result = compute_tree_leaves_naive::( + let naive_result = compute_tree_leaves_naive::( &prf_key, ¶meter, &random_epochs, @@ -1240,16 +1217,16 @@ mod tests { ) { // check encoding is deterministic let tweak1 = PoseidonTweak::ChainTweak { epoch: epoch1, chain_index, pos_in_chain }; - let result1 = tweak1.to_field_elements::<2>(); - let result2 = tweak1.to_field_elements::<2>(); + let result1 = tweak1.to_field_elements(); + let result2 = tweak1.to_field_elements(); prop_assert_eq!(result1, result2); // check output has correct length - prop_assert_eq!(result1.len(), 2); + prop_assert_eq!(result1.len(), TWEAK_LEN_FE); // check different epochs produce different encodings let tweak2 = PoseidonTweak::ChainTweak { epoch: epoch2, chain_index, pos_in_chain }; - let other = tweak2.to_field_elements::<2>(); + let other = tweak2.to_field_elements(); if epoch1 == epoch2 { prop_assert_eq!(result1, other); } else { @@ -1258,7 +1235,7 @@ mod tests { // check chain tweaks differ from tree tweaks (domain separation) let tree_tweak = PoseidonTweak::TreeTweak { level: 0, pos_in_level: epoch1 }; - let tree_result = tree_tweak.to_field_elements::<2>(); + let tree_result = tree_tweak.to_field_elements(); prop_assert_ne!(result1, tree_result); } @@ -1270,16 +1247,16 @@ mod tests { ) { // check encoding is deterministic let tweak1 = PoseidonTweak::TreeTweak { level: level1, pos_in_level }; - let result1 = tweak1.to_field_elements::<2>(); - let result2 = tweak1.to_field_elements::<2>(); + let result1 = tweak1.to_field_elements(); + let result2 = tweak1.to_field_elements(); prop_assert_eq!(result1, result2); // check output has correct length - prop_assert_eq!(result1.len(), 2); + prop_assert_eq!(result1.len(), TWEAK_LEN_FE); // check different levels produce different encodings let tweak2 = PoseidonTweak::TreeTweak { level: level2, pos_in_level }; - let other = tweak2.to_field_elements::<2>(); + let other = tweak2.to_field_elements(); if level1 == level2 { prop_assert_eq!(result1, other); } else { @@ -1287,4 +1264,53 @@ mod tests { } } } + + proptest! { + #[test] + fn proptest_tree_tweak_is_injective( + level in any::(), + pos_in_level in any::() + ) { + let tweak = PoseidonTweak::TreeTweak { level, pos_in_level }; + let encoded = tweak.to_field_elements(); + + // Inverse function: decode the field elements back to original parameters + let low_30_bits = encoded[0].as_canonical_u64() as u32; + let high_2_bits = encoded[1].as_canonical_u64() as u32; + let recovered_pos_in_level = low_30_bits | (high_2_bits << 30); + let third_elem = encoded[2].as_canonical_u64() as u32; + let recovered_level = (third_elem >> 8) as u8; + let recovered_separator = (third_elem & 0xFF) as u8; + + // Verify we recovered the original values + prop_assert_eq!(recovered_pos_in_level, pos_in_level); + prop_assert_eq!(recovered_level, level); + prop_assert_eq!(recovered_separator, TWEAK_SEPARATOR_FOR_TREE_HASH); + } + + #[test] + fn proptest_chain_tweak_is_injective( + epoch in any::(), + chain_index in any::(), + pos_in_chain in any::() + ) { + let tweak = PoseidonTweak::ChainTweak { epoch, chain_index, pos_in_chain }; + let encoded = tweak.to_field_elements(); + + // Inverse function: decode the field elements back to original parameters + let low_30_bits = encoded[0].as_canonical_u64() as u32; + let high_2_bits = encoded[1].as_canonical_u64() as u32; + let recovered_epoch = low_30_bits | (high_2_bits << 30); + let third_elem = encoded[2].as_canonical_u64() as u32; + let recovered_chain_index = (third_elem >> 16) as u8; + let recovered_pos_in_chain = ((third_elem >> 8) & 0xFF) as u8; + let recovered_separator = (third_elem & 0xFF) as u8; + + // Verify we recovered the original values + prop_assert_eq!(recovered_epoch, epoch); + prop_assert_eq!(recovered_chain_index, chain_index); + prop_assert_eq!(recovered_pos_in_chain, pos_in_chain); + prop_assert_eq!(recovered_separator, TWEAK_SEPARATOR_FOR_CHAIN_HASH); + } + } }