From 1a23b5ccb23a87f8299bebc3c289f93f2e531186 Mon Sep 17 00:00:00 2001 From: Thomas Coratger Date: Mon, 17 Nov 2025 19:04:22 +0100 Subject: [PATCH 1/9] simd: apply packing for tree leaves --- Cargo.toml | 8 +- src/lib.rs | 2 + src/signature/generalized_xmss.rs | 37 +- src/symmetric/message_hash/poseidon.rs | 2 +- .../message_hash/top_level_poseidon.rs | 2 +- src/symmetric/tweak_hash.rs | 55 ++- src/symmetric/tweak_hash/poseidon.rs | 364 ++++++++++++++++-- 7 files changed, 397 insertions(+), 73 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index e17aa60..0aefe4a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -39,10 +39,10 @@ dashmap = "6.1.0" serde = { version = "1.0", features = ["derive", "alloc"] } thiserror = "2.0" -p3-field = { git = "https://github.com/Plonky3/Plonky3.git", rev = "2117e4b" } -p3-baby-bear = { git = "https://github.com/Plonky3/Plonky3.git", rev = "2117e4b" } -p3-koala-bear = { git = "https://github.com/Plonky3/Plonky3.git", rev = "2117e4b" } -p3-symmetric = { git = "https://github.com/Plonky3/Plonky3.git", rev = "2117e4b" } +p3-field = { git = "https://github.com/Plonky3/Plonky3.git", rev = "a33a312" } +p3-baby-bear = { git = "https://github.com/Plonky3/Plonky3.git", rev = "a33a312" } +p3-koala-bear = { git = "https://github.com/Plonky3/Plonky3.git", rev = "a33a312" } +p3-symmetric = { git = "https://github.com/Plonky3/Plonky3.git", rev = "a33a312" } [dev-dependencies] criterion = "0.7" diff --git a/src/lib.rs b/src/lib.rs index d7457f9..e29172e 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,3 +1,4 @@ +use p3_field::Field; use p3_koala_bear::{ KoalaBear, Poseidon2KoalaBear, default_koalabear_poseidon2_16, default_koalabear_poseidon2_24, }; @@ -11,6 +12,7 @@ pub const TWEAK_SEPARATOR_FOR_TREE_HASH: u8 = 0x01; pub const TWEAK_SEPARATOR_FOR_CHAIN_HASH: u8 = 0x00; type F = KoalaBear; +pub(crate) type PackedF = ::Packing; pub(crate) mod hypercube; pub(crate) mod inc_encoding; diff --git a/src/signature/generalized_xmss.rs b/src/signature/generalized_xmss.rs index 9f85aa0..529ed98 100644 --- a/src/signature/generalized_xmss.rs +++ b/src/signature/generalized_xmss.rs @@ -205,38 +205,15 @@ where let chain_length = IE::BASE; // the range of epochs covered by that bottom tree - let epoch_range_start = bottom_tree_index * leafs_per_bottom_tree; - let epoch_range_end = epoch_range_start + leafs_per_bottom_tree; - let epoch_range = epoch_range_start..epoch_range_end; - - // parallelize the chain ends hash computation for each epoch in the interval for that bottom tree - let chain_ends_hashes = epoch_range - .into_par_iter() - .map(|epoch| { - // each epoch has a number of chains - // parallelize the chain ends computation for each chain - let chain_ends = (0..num_chains) - .into_par_iter() - .map(|chain_index| { - // each chain start is just a PRF evaluation - let start = - PRF::get_domain_element(prf_key, epoch as u32, chain_index as u64).into(); - // walk the chain to get the public chain end - chain::( - parameter, - epoch as u32, - chain_index as u8, - 0, - chain_length - 1, - &start, - ) - }) - .collect::>(); - // build hash of chain ends / public keys - TH::apply(parameter, &TH::tree_tweak(0, epoch as u32), &chain_ends) - }) + let epoch_start = bottom_tree_index * leafs_per_bottom_tree; + let epochs: Vec = (epoch_start..epoch_start + leafs_per_bottom_tree) + .map(|e| e as u32) .collect(); + // Compute chain ends for all epochs. 
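+    // The TweakableHash implementation may provide a SIMD-packed fast path for this
+    // batch computation (the Poseidon instantiation packs one epoch per SIMD lane).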
+ let chain_ends_hashes = + TH::compute_tree_leaves::(prf_key, parameter, &epochs, num_chains, chain_length); + // now that we have the hashes of all chain ends (= leafs of our tree), we can compute the bottom tree HashSubTree::new_bottom_tree( LOG_LIFETIME, diff --git a/src/symmetric/message_hash/poseidon.rs b/src/symmetric/message_hash/poseidon.rs index f713c3b..2e7fd99 100644 --- a/src/symmetric/message_hash/poseidon.rs +++ b/src/symmetric/message_hash/poseidon.rs @@ -164,7 +164,7 @@ where .copied() .collect(); - let hash_fe = poseidon_compress::<_, 24, HASH_LEN_FE>(&perm, &combined_input_vec); + let hash_fe = poseidon_compress::(&perm, &combined_input_vec); // decode field elements into chunks and return them decode_to_chunks::(&hash_fe).to_vec() diff --git a/src/symmetric/message_hash/top_level_poseidon.rs b/src/symmetric/message_hash/top_level_poseidon.rs index d8d7c39..6ea7998 100644 --- a/src/symmetric/message_hash/top_level_poseidon.rs +++ b/src/symmetric/message_hash/top_level_poseidon.rs @@ -159,7 +159,7 @@ where .collect(); let iteration_pos_output = - poseidon_compress::<_, 24, POS_OUTPUT_LEN_PER_INV_FE>(&perm, &combined_input); + poseidon_compress::(&perm, &combined_input); pos_outputs[i * POS_OUTPUT_LEN_PER_INV_FE..(i + 1) * POS_OUTPUT_LEN_PER_INV_FE] .copy_from_slice(&iteration_pos_output); diff --git a/src/symmetric/tweak_hash.rs b/src/symmetric/tweak_hash.rs index 61144a6..2add975 100644 --- a/src/symmetric/tweak_hash.rs +++ b/src/symmetric/tweak_hash.rs @@ -1,6 +1,9 @@ use rand::Rng; +use rayon::prelude::*; use serde::{Serialize, de::DeserializeOwned}; +use crate::symmetric::prf::Pseudorandom; + /// Trait to model a tweakable hash function. /// Such a function takes a public parameter, a tweak, and a /// message to be hashed. The tweak should be understood as an @@ -14,8 +17,13 @@ use serde::{Serialize, de::DeserializeOwned}; /// to obtain distinct tweaks for applications in chains and /// applications in Merkle trees. pub trait TweakableHash { + /// Public parameter type for the hash function type Parameter: Copy + Sized + Send + Sync + Serialize + DeserializeOwned; + + /// Tweak type for domain separation type Tweak; + + /// Domain element type (defines output and input types to the hash) type Domain: Copy + PartialEq + Sized + Send + Sync + Serialize + DeserializeOwned; /// Generates a random public parameter. @@ -39,8 +47,50 @@ pub trait TweakableHash { message: &[Self::Domain], ) -> Self::Domain; - /// Function to check internal consistency of any given parameters - /// For testing only, and expected to panic if something is wrong. + /// Computes bottom tree leaves by walking hash chains for multiple epochs. + /// + /// This method has a default scalar implementation that processes epochs in parallel. 
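+    /// Implementations are free to override it with a vectorized version; the Poseidon
+    /// tweakable hash does so, packing one epoch per SIMD lane.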
+ fn compute_tree_leaves( + prf_key: &PRF::Key, + parameter: &Self::Parameter, + epochs: &[u32], + num_chains: usize, + chain_length: usize, + ) -> Vec + where + PRF: Pseudorandom, + PRF::Domain: Into, + Self: Sized, + { + // Default scalar implementation: process each epoch in parallel + epochs + .par_iter() + .map(|&epoch| { + // For each epoch, walk all chains in parallel + let chain_ends: Vec<_> = (0..num_chains) + .into_par_iter() + .map(|chain_index| { + let start = + PRF::get_domain_element(prf_key, epoch, chain_index as u64).into(); + chain::( + parameter, + epoch, + chain_index as u8, + 0, + chain_length - 1, + &start, + ) + }) + .collect(); + // Hash all chain ends together to get the leaf + Self::apply(parameter, &Self::tree_tweak(0, epoch), &chain_ends) + }) + .collect() + } + + /// Function to check internal consistency of any given parameters. + /// + /// This is for testing only and is expected to panic if something is wrong. #[cfg(test)] fn internal_consistency_check(); } @@ -77,7 +127,6 @@ pub mod poseidon; #[cfg(test)] mod tests { - use crate::symmetric::tweak_hash::poseidon::PoseidonTweak44; use super::*; diff --git a/src/symmetric/tweak_hash/poseidon.rs b/src/symmetric/tweak_hash/poseidon.rs index ca454f7..77d4b43 100644 --- a/src/symmetric/tweak_hash/poseidon.rs +++ b/src/symmetric/tweak_hash/poseidon.rs @@ -1,13 +1,17 @@ -use p3_field::PrimeCharacteristicRing; -use p3_field::PrimeField64; +use core::array; + +use p3_field::{Algebra, PackedValue, PrimeCharacteristicRing, PrimeField64}; use p3_symmetric::Permutation; +use rayon::prelude::*; use serde::{Serialize, de::DeserializeOwned}; -use crate::F; use crate::TWEAK_SEPARATOR_FOR_CHAIN_HASH; use crate::TWEAK_SEPARATOR_FOR_TREE_HASH; use crate::poseidon2_16; use crate::poseidon2_24; +use crate::symmetric::prf::Pseudorandom; +use crate::symmetric::tweak_hash::chain; +use crate::{F, PackedF}; use super::TweakableHash; @@ -17,6 +21,65 @@ const CHAIN_COMPRESSION_WIDTH: usize = 16; /// The state width for merging two hashes in a tree or for the sponge construction. const MERGE_COMPRESSION_WIDTH: usize = 24; +/// Packs scalar arrays into SIMD-friendly vertical layout. +/// +/// Transposes from horizontal layout `[[F; N]; WIDTH]` to vertical layout `[PackedF; N]`. +/// +/// Input layout (horizontal): each row is one complete array +/// ```text +/// data[0] = [a0, a1, a2, ..., aN] +/// data[1] = [b0, b1, b2, ..., bN] +/// data[2] = [c0, c1, c2, ..., cN] +/// ... +/// ``` +/// +/// Output layout (vertical): each PackedF holds one element from each array +/// ```text +/// result[0] = PackedF([a0, b0, c0, ...]) // All first elements +/// result[1] = PackedF([a1, b1, c1, ...]) // All second elements +/// result[2] = PackedF([a2, b2, c2, ...]) // All third elements +/// ... +/// ``` +/// +/// This vertical packing enables efficient SIMD operations where a single instruction +/// processes the same element position across multiple arrays simultaneously. +#[inline] +fn pack_array(data: &[[F; N]]) -> [PackedF; N] { + array::from_fn(|i| PackedF::from_fn(|j| data[j][i])) +} + +/// Unpacks SIMD vertical layout back into scalar arrays. +/// +/// Transposes from vertical layout `[PackedF; N]` to horizontal layout `[[F; N]; WIDTH]`. +/// +/// This is the inverse operation of `pack_array`. The output buffer must be preallocated +/// with size `[WIDTH][N]` where `WIDTH = PackedF::WIDTH`. 
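+/// A round trip through `pack_array` and `unpack_array` restores the original data.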
+/// +/// Input layout (vertical): each PackedF holds one element from each array +/// ```text +/// packed_data[0] = PackedF([a0, b0, c0, ...]) +/// packed_data[1] = PackedF([a1, b1, c1, ...]) +/// packed_data[2] = PackedF([a2, b2, c2, ...]) +/// ... +/// ``` +/// +/// Output layout (horizontal): each row is one complete array +/// ```text +/// output[0] = [a0, a1, a2, ..., aN] +/// output[1] = [b0, b1, b2, ..., bN] +/// output[2] = [c0, c1, c2, ..., cN] +/// ... +/// ``` +#[inline] +fn unpack_array(packed_data: &[PackedF; N], output: &mut [[F; N]]) { + for (i, data) in packed_data.iter().enumerate().take(N) { + let unpacked_v = data.as_slice(); + for j in 0..PackedF::WIDTH { + output[j][i] = unpacked_v[j]; + } + } +} + /// Enum to implement tweaks. pub enum PoseidonTweak { TreeTweak { @@ -63,17 +126,20 @@ impl PoseidonTweak { } } -/// Poseidon Compression Function. +/// Poseidon Compression Function /// /// Computes: /// PoseidonCompress(x) = Truncate(PoseidonPermute(x) + x) /// -/// This function takes an input slice `x`, applies the Poseidon permutation, -/// adds the original input back (as a feed-forward), and returns the first `OUT_LEN` elements. +/// This function works generically over `A: Algebra`, allowing it to process both: +/// - Scalar fields, +/// - Packed SIMD fields +/// +/// This follows the Plonky3 pattern that enables automatic SIMD optimization. /// /// - `WIDTH`: total state width (input length to permutation). /// - `OUT_LEN`: number of output elements to return. -/// - `perm`: a Poseidon permutation over `[F; WIDTH]`. +/// - `perm`: a Poseidon permutation over `[A; WIDTH]`. /// - `input`: slice of input values, must be `≤ WIDTH` and `≥ OUT_LEN`. /// /// ### Warning: Input Padding @@ -88,12 +154,13 @@ impl PoseidonTweak { /// Panics: /// - If `input.len() < OUT_LEN` /// - If `OUT_LEN > WIDTH` -pub fn poseidon_compress( +pub fn poseidon_compress( perm: &P, - input: &[F], -) -> [F; OUT_LEN] + input: &[A], +) -> [A; OUT_LEN] where - P: Permutation<[F; WIDTH]>, + A: Algebra + Copy, + P: Permutation<[A; WIDTH]>, { assert!( input.len() >= OUT_LEN, @@ -101,7 +168,7 @@ where ); // Copy the input into a fixed-width buffer, zero-padding unused elements if any. - let mut padded_input = [F::ZERO; WIDTH]; + let mut padded_input = [A::ZERO; WIDTH]; padded_input[..input.len()].copy_from_slice(input); // Start with the input as the initial state. @@ -124,18 +191,23 @@ where /// Computes a Poseidon-based domain separator by compressing an array of `u32` /// values using a fixed Poseidon instance. /// +/// This function works generically over `A: Algebra`, allowing it to process both: +/// - Scalar fields, +/// - Packed SIMD fields +/// /// ### Usage constraints /// - This function is private because it's tailored to a very specific case: /// the Poseidon2 instance with arity 24 and a fixed 4-word input. /// - As this function operates on constants, its output can be **precomputed** /// for significant performance gains, especially within a circuit. /// - If generalization is ever needed, a more generic and slower version should be used. 
-fn poseidon_safe_domain_separator( +fn poseidon_safe_domain_separator( perm: &P, params: &[u32; DOMAIN_PARAMETERS_LENGTH], -) -> [F; OUT_LEN] +) -> [A; OUT_LEN] where - P: Permutation<[F; WIDTH]>, + A: Algebra + Copy, + P: Permutation<[A; WIDTH]>, { // Combine params into a single number in base 2^32 // @@ -153,37 +225,44 @@ where let input = std::array::from_fn::<_, 24, _>(|_| { let digit = (acc % F::ORDER_U64 as u128) as u64; acc /= F::ORDER_U64 as u128; - F::from_u64(digit) + A::from_u64(digit) }); - poseidon_compress::<_, WIDTH, OUT_LEN>(perm, &input) + poseidon_compress::(perm, &input) } -/// Poseidon Sponge Hash Function. +/// Poseidon Sponge Hash Function /// /// Absorbs an arbitrary-length input using the Poseidon sponge construction /// and outputs `OUT_LEN` field elements. Domain separation is achieved by /// injecting a `capacity_value` into the state. /// +/// This function works generically over `A: Algebra`, allowing it to process both: +/// - Scalar fields, +/// - Packed SIMD fields +/// +/// ### Parameters /// - `WIDTH`: sponge state width. /// - `OUT_LEN`: number of output elements. -/// - `perm`: Poseidon permutation over `[F; WIDTH]`. -/// - `capacity_value`: values to occupy the capacity part of the state (must be ≤ `WIDTH`). +/// - `perm`: Poseidon permutation over `[A; WIDTH]`. +/// - `capacity_value`: values to occupy the capacity part of the state (must be < `WIDTH`). /// - `input`: message to hash (any length). /// +/// ### Sponge Construction /// This follows the classic sponge structure: -/// - Absorption: inputs are added chunk-by-chunk into the first `rate` elements of the state. -/// - Squeezing: outputs are read from the first `rate` elements of the state, permuted as needed. +/// - **Absorption**: inputs are added chunk-by-chunk into the first `rate` elements of the state. +/// - **Squeezing**: outputs are read from the first `rate` elements of the state, permuted as needed. /// -/// Panics: +/// ### Panics /// - If `capacity_value.len() >= WIDTH` -fn poseidon_sponge( +fn poseidon_sponge( perm: &P, - capacity_value: &[F], - input: &[F], -) -> [F; OUT_LEN] + capacity_value: &[A], + input: &[A], +) -> [A; OUT_LEN] where - P: Permutation<[F; WIDTH]>, + A: Algebra + Copy, + P: Permutation<[A; WIDTH]>, { // The capacity length must be strictly smaller than the width to have a non-zero rate. // This check prevents a panic from subtraction underflow when calculating the rate. @@ -199,10 +278,10 @@ where // // This is safe because the input's original length is effectively encoded // in the `capacity_value`, which serves as a domain separator. 
- input_vector.resize(input.len() + extra_elements, F::ZERO); + input_vector.resize(input.len() + extra_elements, A::ZERO); // initialize - let mut state = [F::ZERO; WIDTH]; + let mut state = [A::ZERO; WIDTH]; state[rate..].copy_from_slice(capacity_value); // absorb @@ -297,7 +376,7 @@ where .chain(single.iter()) .copied() .collect(); - poseidon_compress::<_, CHAIN_COMPRESSION_WIDTH, HASH_LEN>(&perm, &combined_input) + poseidon_compress::(&perm, &combined_input) } [left, right] => { @@ -310,7 +389,7 @@ where .chain(right.iter()) .copied() .collect(); - poseidon_compress::<_, MERGE_COMPRESSION_WIDTH, HASH_LEN>(&perm, &combined_input) + poseidon_compress::(&perm, &combined_input) } _ if message.len() > 2 => { @@ -330,10 +409,10 @@ where HASH_LEN as u32, ]; let capacity_value = - poseidon_safe_domain_separator::<_, MERGE_COMPRESSION_WIDTH, CAPACITY>( + poseidon_safe_domain_separator::( &perm, &lengths, ); - poseidon_sponge::<_, MERGE_COMPRESSION_WIDTH, HASH_LEN>( + poseidon_sponge::( &perm, &capacity_value, &combined_input, @@ -343,6 +422,223 @@ where } } + fn compute_tree_leaves( + prf_key: &PRF::Key, + parameter: &Self::Parameter, + epochs: &[u32], + num_chains: usize, + chain_length: usize, + ) -> Vec + where + PRF: Pseudorandom, + PRF::Domain: Into, + { + // Verify that num_chains matches the encoding dimension. + assert_eq!( + num_chains, NUM_CHUNKS, + "Poseidon SIMD implementation requires num_chains == NUM_CHUNKS. Got num_chains={}, NUM_CHUNKS={}", + num_chains, NUM_CHUNKS + ); + + // SIMD-ACCELERATED IMPLEMENTATION + // + // This path leverages architecture-specific SIMD instructions. + // `PackedF` represents multiple field elements processed in parallel. + // + // The key point: process multiple epochs simultaneously using SIMD. + // Each SIMD lane corresponds to one epoch. + + // Determine SIMD width based on architecture. + let width = PackedF::WIDTH; + + // Allocate output buffer for all leaves. + let mut leaves = vec![[F::ZERO; HASH_LEN]; epochs.len()]; + + // PREPARE PACKED CONSTANTS + + // Broadcast the hash parameter to all SIMD lanes. + // Each lane will use the same parameter for its epoch. + let packed_parameter: [PackedF; PARAMETER_LEN] = + array::from_fn(|i| PackedF::from(parameter[i])); + + // Create Poseidon permutation instances. + // - Width-16 for chain compression, + // - Width-24 for sponge hashing. + let chain_perm = poseidon2_16(); + let sponge_perm = poseidon2_24(); + + // Compute domain separator for the sponge construction. + // This ensures different use cases produce different outputs. + let lengths = [ + PARAMETER_LEN as u32, + TWEAK_LEN as u32, + NUM_CHUNKS as u32, + HASH_LEN as u32, + ]; + let capacity_val = + poseidon_safe_domain_separator::( + &sponge_perm, + &lengths, + ); + + // PARALLEL SIMD PROCESSING + // + // Process epochs in batches of size `width`. + // Each batch is handled by one thread. + // Within each batch, SIMD processes `width` epochs simultaneously. + epochs + .par_chunks_exact(width) + .zip(leaves.par_chunks_exact_mut(width)) + .for_each(|(epoch_chunk, leaves_chunk)| { + // STEP 1: GENERATE AND PACK CHAIN STARTING POINTS + // + // For each chain, generate starting points for all epochs in the chunk. + // Use vertical packing: transpose from [lane][element] to [element][lane]. + // + // This layout enables efficient SIMD operations across epochs. + + let mut packed_chains: [[PackedF; HASH_LEN]; NUM_CHUNKS] = + array::from_fn(|c_idx| { + // Generate starting points for this chain across all epochs. 
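+                    // Each SIMD lane gets one PRF evaluation, keyed by its epoch in `epoch_chunk`.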
+ let starts: [[F; HASH_LEN]; PackedF::WIDTH] = array::from_fn(|lane| { + PRF::get_domain_element(prf_key, epoch_chunk[lane], c_idx as u64).into() + }); + + // Transpose to vertical packing for SIMD efficiency. + pack_array(&starts) + }); + + // STEP 2: WALK CHAINS IN PARALLEL USING SIMD + // + // For each chain, walk all epochs simultaneously using SIMD. + // The chains start at their initial values and are walked step-by-step + // until they reach their endpoints. + // + // Cache strategy: process one chain at a time to maximize locality. + // All epochs for that chain stay in registers across iterations. + + for (chain_index, packed_chain) in + packed_chains.iter_mut().enumerate().take(num_chains) + { + // Walk this chain for `chain_length - 1` steps. + // The starting point is step 0, so we need `chain_length - 1` iterations. + for step in 0..chain_length - 1 { + // Current position in the chain. + let pos = (step + 1) as u8; + + // Generate tweaks for all epochs in this SIMD batch. + // Each lane gets a tweak specific to its epoch. + let packed_tweak = array::from_fn::<_, TWEAK_LEN, _>(|t_idx| { + PackedF::from_fn(|lane| { + Self::chain_tweak(epoch_chunk[lane], chain_index as u8, pos) + .to_field_elements::()[t_idx] + }) + }); + + // Assemble the packed input for the hash function. + // Layout: [parameter | tweak | current_value] + let mut packed_input = [PackedF::ZERO; CHAIN_COMPRESSION_WIDTH]; + let mut current_pos = 0; + + // Copy parameter into the input buffer. + packed_input[current_pos..current_pos + PARAMETER_LEN] + .copy_from_slice(&packed_parameter); + current_pos += PARAMETER_LEN; + + // Copy tweak into the input buffer. + packed_input[current_pos..current_pos + TWEAK_LEN] + .copy_from_slice(&packed_tweak); + current_pos += TWEAK_LEN; + + // Copy current chain value into the input buffer. + packed_input[current_pos..current_pos + HASH_LEN] + .copy_from_slice(packed_chain); + + // Apply the hash function to advance the chain. + // This single call processes all epochs in parallel. + *packed_chain = + poseidon_compress::( + &chain_perm, + &packed_input, + ); + } + } + + // STEP 3: HASH CHAIN ENDS TO PRODUCE TREE LEAVES + // + // All chains have been walked to their endpoints. + // Now hash all chain ends together to form the tree leaf. + // + // This uses the sponge construction for variable-length input. + + // Generate tree tweaks for all epochs. + // Level 0 indicates this is a bottom-layer leaf in the tree. + let packed_tree_tweak = array::from_fn::<_, TWEAK_LEN, _>(|t_idx| { + PackedF::from_fn(|lane| { + Self::tree_tweak(0, epoch_chunk[lane]).to_field_elements::() + [t_idx] + }) + }); + + // Assemble the sponge input. + // Layout: [parameter | tree_tweak | all_chain_ends] + let packed_leaf_input: Vec<_> = packed_parameter + .iter() + .chain(packed_tree_tweak.iter()) + .chain(packed_chains.iter().flatten()) + .copied() + .collect(); + + // Apply the sponge hash to produce the leaf. + // This absorbs all chain ends and squeezes out the final hash. + let packed_leaves = poseidon_sponge::( + &sponge_perm, + &capacity_val, + &packed_leaf_input, + ); + + // STEP 4: UNPACK RESULTS TO SCALAR REPRESENTATION + // + // Convert from vertical packing back to scalar layout. + // Each lane becomes one leaf in the output slice. + + unpack_array(&packed_leaves, leaves_chunk); + }); + + // HANDLE REMAINDER EPOCHS + // + // If the total number of epochs is not divisible by the SIMD width, + // process the remaining epochs using scalar code. 
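+            // At most `PackedF::WIDTH - 1` epochs can fall into this scalar tail.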
+ // + // This ensures correctness for all input sizes. + + let remainder_start = (epochs.len() / width) * width; + for (i, epoch) in epochs[remainder_start..].iter().enumerate() { + let global_index = remainder_start + i; + + // Walk all chains for this epoch. + let chain_ends: Vec<_> = (0..NUM_CHUNKS) + .map(|chain_index| { + let start = PRF::get_domain_element(prf_key, *epoch, chain_index as u64).into(); + chain::( + parameter, + *epoch, + chain_index as u8, + 0, + chain_length - 1, + &start, + ) + }) + .collect(); + + // Hash the chain ends to produce the leaf. + leaves[global_index] = + Self::apply(parameter, &Self::tree_tweak(0, *epoch), &chain_ends); + } + + leaves + } + #[cfg(test)] fn internal_consistency_check() { assert!( From 825a2ecfd5e04b5c0f9e30090c0d04cef83960c6 Mon Sep 17 00:00:00 2001 From: Thomas Coratger Date: Mon, 17 Nov 2025 19:54:33 +0100 Subject: [PATCH 2/9] rm useless default compute_tree_leaves --- src/symmetric/tweak_hash.rs | 29 +---------------------------- 1 file changed, 1 insertion(+), 28 deletions(-) diff --git a/src/symmetric/tweak_hash.rs b/src/symmetric/tweak_hash.rs index 2add975..dae6d1c 100644 --- a/src/symmetric/tweak_hash.rs +++ b/src/symmetric/tweak_hash.rs @@ -1,5 +1,4 @@ use rand::Rng; -use rayon::prelude::*; use serde::{Serialize, de::DeserializeOwned}; use crate::symmetric::prf::Pseudorandom; @@ -60,33 +59,7 @@ pub trait TweakableHash { where PRF: Pseudorandom, PRF::Domain: Into, - Self: Sized, - { - // Default scalar implementation: process each epoch in parallel - epochs - .par_iter() - .map(|&epoch| { - // For each epoch, walk all chains in parallel - let chain_ends: Vec<_> = (0..num_chains) - .into_par_iter() - .map(|chain_index| { - let start = - PRF::get_domain_element(prf_key, epoch, chain_index as u64).into(); - chain::( - parameter, - epoch, - chain_index as u8, - 0, - chain_length - 1, - &start, - ) - }) - .collect(); - // Hash all chain ends together to get the leaf - Self::apply(parameter, &Self::tree_tweak(0, epoch), &chain_ends) - }) - .collect() - } + Self: Sized; /// Function to check internal consistency of any given parameters. /// From 8b4f5e8d7c673db070f3640207d6a245b83dd99c Mon Sep 17 00:00:00 2001 From: Thomas Coratger Date: Mon, 17 Nov 2025 21:04:18 +0100 Subject: [PATCH 3/9] mv simd utils to a specific file --- src/symmetric.rs | 1 + src/symmetric/simd_utils.rs | 178 +++++++++++++++++++++++++++ src/symmetric/tweak_hash/poseidon.rs | 60 +-------- 3 files changed, 180 insertions(+), 59 deletions(-) create mode 100644 src/symmetric/simd_utils.rs diff --git a/src/symmetric.rs b/src/symmetric.rs index c688b38..2f4e422 100644 --- a/src/symmetric.rs +++ b/src/symmetric.rs @@ -1,4 +1,5 @@ pub mod message_hash; pub mod prf; +pub mod simd_utils; pub mod tweak_hash; pub mod tweak_hash_tree; diff --git a/src/symmetric/simd_utils.rs b/src/symmetric/simd_utils.rs new file mode 100644 index 0000000..435422e --- /dev/null +++ b/src/symmetric/simd_utils.rs @@ -0,0 +1,178 @@ +use core::array; + +use p3_field::PackedValue; + +use crate::{F, PackedF}; + +/// Packs scalar arrays into SIMD-friendly vertical layout. +/// +/// Transposes from horizontal layout `[[F; N]; WIDTH]` to vertical layout `[PackedF; N]`. +/// +/// Input layout (horizontal): each row is one complete array +/// ```text +/// data[0] = [a0, a1, a2, ..., aN] +/// data[1] = [b0, b1, b2, ..., bN] +/// data[2] = [c0, c1, c2, ..., cN] +/// ... 
+/// ``` +/// +/// Output layout (vertical): each PackedF holds one element from each array +/// ```text +/// result[0] = PackedF([a0, b0, c0, ...]) // All first elements +/// result[1] = PackedF([a1, b1, c1, ...]) // All second elements +/// result[2] = PackedF([a2, b2, c2, ...]) // All third elements +/// ... +/// ``` +/// +/// This vertical packing enables efficient SIMD operations where a single instruction +/// processes the same element position across multiple arrays simultaneously. +#[inline] +pub fn pack_array(data: &[[F; N]]) -> [PackedF; N] { + array::from_fn(|i| PackedF::from_fn(|j| data[j][i])) +} + +/// Unpacks SIMD vertical layout back into scalar arrays. +/// +/// Transposes from vertical layout `[PackedF; N]` to horizontal layout `[[F; N]; WIDTH]`. +/// +/// This is the inverse operation of `pack_array`. The output buffer must be preallocated +/// with size `[WIDTH][N]` where `WIDTH = PackedF::WIDTH`. +/// +/// Input layout (vertical): each PackedF holds one element from each array +/// ```text +/// packed_data[0] = PackedF([a0, b0, c0, ...]) +/// packed_data[1] = PackedF([a1, b1, c1, ...]) +/// packed_data[2] = PackedF([a2, b2, c2, ...]) +/// ... +/// ``` +/// +/// Output layout (horizontal): each row is one complete array +/// ```text +/// output[0] = [a0, a1, a2, ..., aN] +/// output[1] = [b0, b1, b2, ..., bN] +/// output[2] = [c0, c1, c2, ..., cN] +/// ... +/// ``` +#[inline] +pub fn unpack_array(packed_data: &[PackedF; N], output: &mut [[F; N]]) { + for (i, data) in packed_data.iter().enumerate().take(N) { + let unpacked_v = data.as_slice(); + for j in 0..PackedF::WIDTH { + output[j][i] = unpacked_v[j]; + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use p3_field::PrimeCharacteristicRing; + use proptest::prelude::*; + use rand::Rng; + + #[test] + fn test_pack_array_simple() { + // Test with N=2 (2 field elements per array) + // Create WIDTH arrays of [F; 2] + let data: [[F; 2]; PackedF::WIDTH] = + array::from_fn(|i| [F::from_u64(i as u64), F::from_u64((i + 100) as u64)]); + + let packed = pack_array(&data); + + // Check that packed[0] contains all first elements + for (lane, &expected) in data.iter().enumerate() { + assert_eq!(packed[0].as_slice()[lane], expected[0]); + } + + // Check that packed[1] contains all second elements + for (lane, &expected) in data.iter().enumerate() { + assert_eq!(packed[1].as_slice()[lane], expected[1]); + } + } + + #[test] + fn test_unpack_array_simple() { + // Create packed data + let packed: [PackedF; 2] = [ + PackedF::from_fn(|i| F::from_u64(i as u64)), + PackedF::from_fn(|i| F::from_u64((i + 100) as u64)), + ]; + + // Unpack + let mut output = [[F::ZERO; 2]; PackedF::WIDTH]; + unpack_array(&packed, &mut output); + + // Verify + for (lane, arr) in output.iter().enumerate() { + assert_eq!(arr[0], F::from_u64(lane as u64)); + assert_eq!(arr[1], F::from_u64((lane + 100) as u64)); + } + } + + #[test] + fn test_pack_preserves_element_order() { + // Create data where each array has sequential values + let data: [[F; 3]; PackedF::WIDTH] = array::from_fn(|i| { + [ + F::from_u64((i * 3) as u64), + F::from_u64((i * 3 + 1) as u64), + F::from_u64((i * 3 + 2) as u64), + ] + }); + + let packed = pack_array(&data); + + // Verify the packing structure + // packed[0] should contain: [0, 3, 6, 9, ...] + // packed[1] should contain: [1, 4, 7, 10, ...] + // packed[2] should contain: [2, 5, 8, 11, ...] 
+ for (element_idx, p) in packed.iter().enumerate() { + for lane in 0..PackedF::WIDTH { + let expected = F::from_u64((lane * 3 + element_idx) as u64); + assert_eq!(p.as_slice()[lane], expected); + } + } + } + + #[test] + fn test_unpack_preserves_element_order() { + // Create packed data with known pattern + let packed: [PackedF; 3] = [ + PackedF::from_fn(|i| F::from_u64((i * 3) as u64)), + PackedF::from_fn(|i| F::from_u64((i * 3 + 1) as u64)), + PackedF::from_fn(|i| F::from_u64((i * 3 + 2) as u64)), + ]; + + let mut output = [[F::ZERO; 3]; PackedF::WIDTH]; + unpack_array(&packed, &mut output); + + // Verify each array has sequential values + for (lane, arr) in output.iter().enumerate() { + assert_eq!(arr[0], F::from_u64((lane * 3) as u64)); + assert_eq!(arr[1], F::from_u64((lane * 3 + 1) as u64)); + assert_eq!(arr[2], F::from_u64((lane * 3 + 2) as u64)); + } + } + + proptest! { + #[test] + fn proptest_pack_unpack_roundtrip( + _seed in any::() + ) { + let mut rng = rand::rng(); + + // Generate random data with N=10 + let original: [[F; 10]; PackedF::WIDTH] = array::from_fn(|_| { + array::from_fn(|_| rng.random()) + }); + + // Pack and unpack + let packed = pack_array(&original); + let mut unpacked = [[F::ZERO; 10]; PackedF::WIDTH]; + unpack_array(&packed, &mut unpacked); + + // Verify roundtrip + prop_assert_eq!(original, unpacked); + } + } +} diff --git a/src/symmetric/tweak_hash/poseidon.rs b/src/symmetric/tweak_hash/poseidon.rs index 77d4b43..2f6c87e 100644 --- a/src/symmetric/tweak_hash/poseidon.rs +++ b/src/symmetric/tweak_hash/poseidon.rs @@ -10,6 +10,7 @@ use crate::TWEAK_SEPARATOR_FOR_TREE_HASH; use crate::poseidon2_16; use crate::poseidon2_24; use crate::symmetric::prf::Pseudorandom; +use crate::symmetric::simd_utils::{pack_array, unpack_array}; use crate::symmetric::tweak_hash::chain; use crate::{F, PackedF}; @@ -21,65 +22,6 @@ const CHAIN_COMPRESSION_WIDTH: usize = 16; /// The state width for merging two hashes in a tree or for the sponge construction. const MERGE_COMPRESSION_WIDTH: usize = 24; -/// Packs scalar arrays into SIMD-friendly vertical layout. -/// -/// Transposes from horizontal layout `[[F; N]; WIDTH]` to vertical layout `[PackedF; N]`. -/// -/// Input layout (horizontal): each row is one complete array -/// ```text -/// data[0] = [a0, a1, a2, ..., aN] -/// data[1] = [b0, b1, b2, ..., bN] -/// data[2] = [c0, c1, c2, ..., cN] -/// ... -/// ``` -/// -/// Output layout (vertical): each PackedF holds one element from each array -/// ```text -/// result[0] = PackedF([a0, b0, c0, ...]) // All first elements -/// result[1] = PackedF([a1, b1, c1, ...]) // All second elements -/// result[2] = PackedF([a2, b2, c2, ...]) // All third elements -/// ... -/// ``` -/// -/// This vertical packing enables efficient SIMD operations where a single instruction -/// processes the same element position across multiple arrays simultaneously. -#[inline] -fn pack_array(data: &[[F; N]]) -> [PackedF; N] { - array::from_fn(|i| PackedF::from_fn(|j| data[j][i])) -} - -/// Unpacks SIMD vertical layout back into scalar arrays. -/// -/// Transposes from vertical layout `[PackedF; N]` to horizontal layout `[[F; N]; WIDTH]`. -/// -/// This is the inverse operation of `pack_array`. The output buffer must be preallocated -/// with size `[WIDTH][N]` where `WIDTH = PackedF::WIDTH`. 
-/// -/// Input layout (vertical): each PackedF holds one element from each array -/// ```text -/// packed_data[0] = PackedF([a0, b0, c0, ...]) -/// packed_data[1] = PackedF([a1, b1, c1, ...]) -/// packed_data[2] = PackedF([a2, b2, c2, ...]) -/// ... -/// ``` -/// -/// Output layout (horizontal): each row is one complete array -/// ```text -/// output[0] = [a0, a1, a2, ..., aN] -/// output[1] = [b0, b1, b2, ..., bN] -/// output[2] = [c0, c1, c2, ..., cN] -/// ... -/// ``` -#[inline] -fn unpack_array(packed_data: &[PackedF; N], output: &mut [[F; N]]) { - for (i, data) in packed_data.iter().enumerate().take(N) { - let unpacked_v = data.as_slice(); - for j in 0..PackedF::WIDTH { - output[j][i] = unpacked_v[j]; - } - } -} - /// Enum to implement tweaks. pub enum PoseidonTweak { TreeTweak { From 1b7f546020506f32dd2347b5e4f6a1034ebb9626 Mon Sep 17 00:00:00 2001 From: Thomas Coratger Date: Mon, 17 Nov 2025 21:24:48 +0100 Subject: [PATCH 4/9] bench --- benches/benchmark.rs | 2 +- benches/benchmark_poseidon_top_level.rs | 64 ++++++++++++------------- 2 files changed, 33 insertions(+), 33 deletions(-) diff --git a/benches/benchmark.rs b/benches/benchmark.rs index 235e803..4c01b4c 100644 --- a/benches/benchmark.rs +++ b/benches/benchmark.rs @@ -9,6 +9,6 @@ use benchmark_poseidon_top_level::bench_function_poseidon_top_level; criterion_group!( benches, bench_function_poseidon_top_level, - bench_function_poseidon + // bench_function_poseidon ); criterion_main!(benches); diff --git a/benches/benchmark_poseidon_top_level.rs b/benches/benchmark_poseidon_top_level.rs index aa5f235..f2bdb25 100644 --- a/benches/benchmark_poseidon_top_level.rs +++ b/benches/benchmark_poseidon_top_level.rs @@ -39,7 +39,7 @@ pub fn benchmark_signature_scheme(c: &mut Criterion, descrip // Note: benchmarking key generation takes long, so it is // commented out for now. You can enable it here. 
- #[cfg(feature = "with-gen-benches-poseidon-top-level")] + // #[cfg(feature = "with-gen-benches-poseidon-top-level")] group.bench_function("- gen", |b| { b.iter(|| { // Benchmark key generation @@ -108,35 +108,35 @@ pub fn bench_function_poseidon_top_level(c: &mut Criterion) { ), ); - // benchmarking lifetime 2^18 - benchmark_signature_scheme::( - c, - &format!( - "Top Level TS, Lifetime 2^18, Activation 2^{MAX_LOG_ACTIVATION_DURATION}, Dimension 64, Base 8" - ), - ); - - // benchmarking lifetime 2^32 - hashing optimized - benchmark_signature_scheme::( - c, - &format!( - "Top Level TS, Lifetime 2^32, Activation 2^{MAX_LOG_ACTIVATION_DURATION}, Dimension 64, Base 8 (Hashing Optimized)" - ), - ); - - // benchmarking lifetime 2^32 - trade-off - benchmark_signature_scheme::( - c, - &format!( - "Top Level TS, Lifetime 2^32, Activation 2^{MAX_LOG_ACTIVATION_DURATION}, Dimension 48, Base 10 (Trade-off)" - ), - ); - - // benchmarking lifetime 2^32 - size optimized - benchmark_signature_scheme::( - c, - &format!( - "Top Level TS, Lifetime 2^32, Activation 2^{MAX_LOG_ACTIVATION_DURATION}, Dimension 32, Base 26 (Size Optimized)" - ), - ); + // // benchmarking lifetime 2^18 + // benchmark_signature_scheme::( + // c, + // &format!( + // "Top Level TS, Lifetime 2^18, Activation 2^{MAX_LOG_ACTIVATION_DURATION}, Dimension 64, Base 8" + // ), + // ); + + // // benchmarking lifetime 2^32 - hashing optimized + // benchmark_signature_scheme::( + // c, + // &format!( + // "Top Level TS, Lifetime 2^32, Activation 2^{MAX_LOG_ACTIVATION_DURATION}, Dimension 64, Base 8 (Hashing Optimized)" + // ), + // ); + + // // benchmarking lifetime 2^32 - trade-off + // benchmark_signature_scheme::( + // c, + // &format!( + // "Top Level TS, Lifetime 2^32, Activation 2^{MAX_LOG_ACTIVATION_DURATION}, Dimension 48, Base 10 (Trade-off)" + // ), + // ); + + // // benchmarking lifetime 2^32 - size optimized + // benchmark_signature_scheme::( + // c, + // &format!( + // "Top Level TS, Lifetime 2^32, Activation 2^{MAX_LOG_ACTIVATION_DURATION}, Dimension 32, Base 26 (Size Optimized)" + // ), + // ); } From efe2d0cc5a9925ec731f0e4316712c0966c95eb4 Mon Sep 17 00:00:00 2001 From: Thomas Coratger Date: Mon, 17 Nov 2025 22:02:07 +0100 Subject: [PATCH 5/9] fix bench --- benches/benchmark.rs | 2 +- benches/benchmark_poseidon_top_level.rs | 64 ++++++++++++------------- 2 files changed, 33 insertions(+), 33 deletions(-) diff --git a/benches/benchmark.rs b/benches/benchmark.rs index 4c01b4c..235e803 100644 --- a/benches/benchmark.rs +++ b/benches/benchmark.rs @@ -9,6 +9,6 @@ use benchmark_poseidon_top_level::bench_function_poseidon_top_level; criterion_group!( benches, bench_function_poseidon_top_level, - // bench_function_poseidon + bench_function_poseidon ); criterion_main!(benches); diff --git a/benches/benchmark_poseidon_top_level.rs b/benches/benchmark_poseidon_top_level.rs index f2bdb25..aa5f235 100644 --- a/benches/benchmark_poseidon_top_level.rs +++ b/benches/benchmark_poseidon_top_level.rs @@ -39,7 +39,7 @@ pub fn benchmark_signature_scheme(c: &mut Criterion, descrip // Note: benchmarking key generation takes long, so it is // commented out for now. You can enable it here. 
- // #[cfg(feature = "with-gen-benches-poseidon-top-level")] + #[cfg(feature = "with-gen-benches-poseidon-top-level")] group.bench_function("- gen", |b| { b.iter(|| { // Benchmark key generation @@ -108,35 +108,35 @@ pub fn bench_function_poseidon_top_level(c: &mut Criterion) { ), ); - // // benchmarking lifetime 2^18 - // benchmark_signature_scheme::( - // c, - // &format!( - // "Top Level TS, Lifetime 2^18, Activation 2^{MAX_LOG_ACTIVATION_DURATION}, Dimension 64, Base 8" - // ), - // ); - - // // benchmarking lifetime 2^32 - hashing optimized - // benchmark_signature_scheme::( - // c, - // &format!( - // "Top Level TS, Lifetime 2^32, Activation 2^{MAX_LOG_ACTIVATION_DURATION}, Dimension 64, Base 8 (Hashing Optimized)" - // ), - // ); - - // // benchmarking lifetime 2^32 - trade-off - // benchmark_signature_scheme::( - // c, - // &format!( - // "Top Level TS, Lifetime 2^32, Activation 2^{MAX_LOG_ACTIVATION_DURATION}, Dimension 48, Base 10 (Trade-off)" - // ), - // ); - - // // benchmarking lifetime 2^32 - size optimized - // benchmark_signature_scheme::( - // c, - // &format!( - // "Top Level TS, Lifetime 2^32, Activation 2^{MAX_LOG_ACTIVATION_DURATION}, Dimension 32, Base 26 (Size Optimized)" - // ), - // ); + // benchmarking lifetime 2^18 + benchmark_signature_scheme::( + c, + &format!( + "Top Level TS, Lifetime 2^18, Activation 2^{MAX_LOG_ACTIVATION_DURATION}, Dimension 64, Base 8" + ), + ); + + // benchmarking lifetime 2^32 - hashing optimized + benchmark_signature_scheme::( + c, + &format!( + "Top Level TS, Lifetime 2^32, Activation 2^{MAX_LOG_ACTIVATION_DURATION}, Dimension 64, Base 8 (Hashing Optimized)" + ), + ); + + // benchmarking lifetime 2^32 - trade-off + benchmark_signature_scheme::( + c, + &format!( + "Top Level TS, Lifetime 2^32, Activation 2^{MAX_LOG_ACTIVATION_DURATION}, Dimension 48, Base 10 (Trade-off)" + ), + ); + + // benchmarking lifetime 2^32 - size optimized + benchmark_signature_scheme::( + c, + &format!( + "Top Level TS, Lifetime 2^32, Activation 2^{MAX_LOG_ACTIVATION_DURATION}, Dimension 32, Base 26 (Size Optimized)" + ), + ); } From 7fd3cc6c8c3f81734cb328bf3e4e2af52572667a Mon Sep 17 00:00:00 2001 From: Thomas Coratger Date: Wed, 19 Nov 2025 10:00:56 +0100 Subject: [PATCH 6/9] fix Angus comments --- src/symmetric/tweak_hash/poseidon.rs | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/src/symmetric/tweak_hash/poseidon.rs b/src/symmetric/tweak_hash/poseidon.rs index 2f6c87e..30a93e6 100644 --- a/src/symmetric/tweak_hash/poseidon.rs +++ b/src/symmetric/tweak_hash/poseidon.rs @@ -1,7 +1,7 @@ use core::array; use p3_field::{Algebra, PackedValue, PrimeCharacteristicRing, PrimeField64}; -use p3_symmetric::Permutation; +use p3_symmetric::CryptographicPermutation; use rayon::prelude::*; use serde::{Serialize, de::DeserializeOwned}; @@ -73,7 +73,7 @@ impl PoseidonTweak { /// Computes: /// PoseidonCompress(x) = Truncate(PoseidonPermute(x) + x) /// -/// This function works generically over `A: Algebra`, allowing it to process both: +/// This function works generically over `R: PrimeCharacteristicRing`, allowing it to process both: /// - Scalar fields, /// - Packed SIMD fields /// @@ -81,7 +81,7 @@ impl PoseidonTweak { /// /// - `WIDTH`: total state width (input length to permutation). /// - `OUT_LEN`: number of output elements to return. -/// - `perm`: a Poseidon permutation over `[A; WIDTH]`. +/// - `perm`: a cryptographically secure Poseidon permutation over `[R; WIDTH]`. 
/// - `input`: slice of input values, must be `≤ WIDTH` and `≥ OUT_LEN`. /// /// ### Warning: Input Padding @@ -96,13 +96,13 @@ impl PoseidonTweak { /// Panics: /// - If `input.len() < OUT_LEN` /// - If `OUT_LEN > WIDTH` -pub fn poseidon_compress( +pub fn poseidon_compress( perm: &P, - input: &[A], -) -> [A; OUT_LEN] + input: &[R], +) -> [R; OUT_LEN] where - A: Algebra + Copy, - P: Permutation<[A; WIDTH]>, + R: PrimeCharacteristicRing + Copy, + P: CryptographicPermutation<[R; WIDTH]>, { assert!( input.len() >= OUT_LEN, @@ -110,7 +110,7 @@ where ); // Copy the input into a fixed-width buffer, zero-padding unused elements if any. - let mut padded_input = [A::ZERO; WIDTH]; + let mut padded_input = [R::ZERO; WIDTH]; padded_input[..input.len()].copy_from_slice(input); // Start with the input as the initial state. @@ -149,7 +149,7 @@ fn poseidon_safe_domain_separator [A; OUT_LEN] where A: Algebra + Copy, - P: Permutation<[A; WIDTH]>, + P: CryptographicPermutation<[A; WIDTH]>, { // Combine params into a single number in base 2^32 // @@ -204,7 +204,7 @@ fn poseidon_sponge( ) -> [A; OUT_LEN] where A: Algebra + Copy, - P: Permutation<[A; WIDTH]>, + P: CryptographicPermutation<[A; WIDTH]>, { // The capacity length must be strictly smaller than the width to have a non-zero rate. // This check prevents a panic from subtraction underflow when calculating the rate. From be15969fd71aae462effd75957570fafb02b2b6c Mon Sep 17 00:00:00 2001 From: Thomas Coratger Date: Thu, 20 Nov 2025 09:55:24 +0100 Subject: [PATCH 7/9] add unit tests against compute_tree_leaves_naive --- src/symmetric/tweak_hash/poseidon.rs | 167 +++++++++++++++++++++++++++ 1 file changed, 167 insertions(+) diff --git a/src/symmetric/tweak_hash/poseidon.rs b/src/symmetric/tweak_hash/poseidon.rs index 30a93e6..62294c1 100644 --- a/src/symmetric/tweak_hash/poseidon.rs +++ b/src/symmetric/tweak_hash/poseidon.rs @@ -632,6 +632,8 @@ mod tests { use num_bigint::BigUint; use rand::Rng; + use crate::symmetric::prf::shake_to_field::ShakePRFtoF; + use super::*; #[test] @@ -1021,4 +1023,169 @@ mod tests { } } } + + /// Naive/scalar implementation of compute_tree_leaves for testing purposes. 
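+    /// It walks every chain of every epoch sequentially and serves as the reference
+    /// oracle against which the SIMD implementation is checked.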
+ fn compute_tree_leaves_naive< + TH: TweakableHash, + PRF: Pseudorandom, + const PARAMETER_LEN: usize, + const HASH_LEN: usize, + const TWEAK_LEN: usize, + const CAPACITY: usize, + const NUM_CHUNKS: usize, + >( + prf_key: &PRF::Key, + parameter: &TH::Parameter, + epochs: &[u32], + num_chains: usize, + chain_length: usize, + ) -> Vec + where + PRF::Domain: Into, + { + // Process each epoch in parallel + epochs + .iter() + .map(|&epoch| { + // For each epoch, walk all chains in parallel + let chain_ends = (0..num_chains) + .into_iter() + .map(|chain_index| { + // Each chain start is just a PRF evaluation + let start = + PRF::get_domain_element(prf_key, epoch, chain_index as u64).into(); + // Walk the chain to get the public chain end + chain::( + parameter, + epoch, + chain_index as u8, + 0, + chain_length - 1, + &start, + ) + }) + .collect::>(); + // Build hash of chain ends / public keys + TH::apply(parameter, &TH::tree_tweak(0, epoch), &chain_ends) + }) + .collect() + } + + #[test] + fn test_compute_tree_leaves_matches_naive() { + type TestPRF = ShakePRFtoF<4, 4>; + type TestTH = PoseidonTweak44; + + let mut rng = rand::rng(); + + // Generate test parameters + let prf_key = TestPRF::key_gen(&mut rng); + let parameter = TestTH::rand_parameter(&mut rng); + + // Test with different numbers of epochs to cover both SIMD and remainder paths + let test_cases = vec![ + // Small cases that fit in one SIMD batch + vec![0, 1, 2, 3], + // Exact multiple of SIMD width (assuming width is typically 4, 8, or 16) + vec![0, 1, 2, 3, 4, 5, 6, 7], + vec![0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15], + // Non-multiple of SIMD width to test remainder handling + vec![0, 1, 2, 3, 4, 5], + vec![0, 1, 2, 3, 4, 5, 6, 7, 8, 9], + vec![0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17], + ]; + + let num_chains = 128; + let chain_length = 10; + + for epochs in test_cases { + // Compute using SIMD implementation + let simd_result = TestTH::compute_tree_leaves::( + &prf_key, + ¶meter, + &epochs, + num_chains, + chain_length, + ); + + // Compute using naive/scalar implementation + let naive_result = compute_tree_leaves_naive::( + &prf_key, + ¶meter, + &epochs, + num_chains, + chain_length, + ); + + // Results should match exactly + assert_eq!( + simd_result.len(), + naive_result.len(), + "SIMD and naive implementations produced different number of leaves for epochs {:?}", + epochs + ); + + for (i, (simd_leaf, naive_leaf)) in + simd_result.iter().zip(naive_result.iter()).enumerate() + { + assert_eq!( + simd_leaf, naive_leaf, + "Mismatch at epoch index {} (epoch {}): SIMD and naive implementations produced different results", + i, epochs[i] + ); + } + } + } + + #[test] + fn test_compute_tree_leaves_matches_naive_random_epochs() { + type TestPRF = ShakePRFtoF<4, 4>; + type TestTH = PoseidonTweak44; + + let mut rng = rand::rng(); + + // Generate test parameters + let prf_key = TestPRF::key_gen(&mut rng); + let parameter = TestTH::rand_parameter(&mut rng); + + let num_chains = 128; + let chain_length = 10; + + // Test with random epochs (not necessarily sequential) + let random_epochs: Vec = (0..17).map(|_| rng.random::() % 1000).collect(); + + // Compute using SIMD implementation + let simd_result = TestTH::compute_tree_leaves::( + &prf_key, + ¶meter, + &random_epochs, + num_chains, + chain_length, + ); + + // Compute using naive/scalar implementation + let naive_result = compute_tree_leaves_naive::( + &prf_key, + ¶meter, + &random_epochs, + num_chains, + chain_length, + ); + + // Results should match 
exactly + assert_eq!( + simd_result.len(), + naive_result.len(), + "SIMD and naive implementations produced different number of leaves" + ); + + for (i, (simd_leaf, naive_leaf)) in simd_result.iter().zip(naive_result.iter()).enumerate() + { + assert_eq!( + simd_leaf, naive_leaf, + "Mismatch at epoch index {} (epoch {}): SIMD and naive implementations produced different results", + i, random_epochs[i] + ); + } + } } From e0040f1c80a6bb7dba2946fab656a03fb132d330 Mon Sep 17 00:00:00 2001 From: Thomas Coratger Date: Thu, 20 Nov 2025 09:57:00 +0100 Subject: [PATCH 8/9] clippy --- src/symmetric/tweak_hash/poseidon.rs | 1 - 1 file changed, 1 deletion(-) diff --git a/src/symmetric/tweak_hash/poseidon.rs b/src/symmetric/tweak_hash/poseidon.rs index 62294c1..ed242e4 100644 --- a/src/symmetric/tweak_hash/poseidon.rs +++ b/src/symmetric/tweak_hash/poseidon.rs @@ -1049,7 +1049,6 @@ mod tests { .map(|&epoch| { // For each epoch, walk all chains in parallel let chain_ends = (0..num_chains) - .into_iter() .map(|chain_index| { // Each chain start is just a PRF evaluation let start = From 68c7234044f8ca6c899d0b31ed0ed940b03f987e Mon Sep 17 00:00:00 2001 From: Thomas Coratger Date: Thu, 20 Nov 2025 10:05:58 +0100 Subject: [PATCH 9/9] mv simd_utils to root --- src/lib.rs | 1 + src/{symmetric => }/simd_utils.rs | 0 src/symmetric.rs | 1 - src/symmetric/tweak_hash/poseidon.rs | 2 +- 4 files changed, 2 insertions(+), 2 deletions(-) rename src/{symmetric => }/simd_utils.rs (100%) diff --git a/src/lib.rs b/src/lib.rs index e29172e..53f51b6 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -17,6 +17,7 @@ pub(crate) type PackedF = ::Packing; pub(crate) mod hypercube; pub(crate) mod inc_encoding; pub mod signature; +pub(crate) mod simd_utils; pub(crate) mod symmetric; // Cached Poseidon2 permutations. diff --git a/src/symmetric/simd_utils.rs b/src/simd_utils.rs similarity index 100% rename from src/symmetric/simd_utils.rs rename to src/simd_utils.rs diff --git a/src/symmetric.rs b/src/symmetric.rs index 2f4e422..c688b38 100644 --- a/src/symmetric.rs +++ b/src/symmetric.rs @@ -1,5 +1,4 @@ pub mod message_hash; pub mod prf; -pub mod simd_utils; pub mod tweak_hash; pub mod tweak_hash_tree; diff --git a/src/symmetric/tweak_hash/poseidon.rs b/src/symmetric/tweak_hash/poseidon.rs index ed242e4..a584907 100644 --- a/src/symmetric/tweak_hash/poseidon.rs +++ b/src/symmetric/tweak_hash/poseidon.rs @@ -9,8 +9,8 @@ use crate::TWEAK_SEPARATOR_FOR_CHAIN_HASH; use crate::TWEAK_SEPARATOR_FOR_TREE_HASH; use crate::poseidon2_16; use crate::poseidon2_24; +use crate::simd_utils::{pack_array, unpack_array}; use crate::symmetric::prf::Pseudorandom; -use crate::symmetric::simd_utils::{pack_array, unpack_array}; use crate::symmetric::tweak_hash::chain; use crate::{F, PackedF};