Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
125 changes: 120 additions & 5 deletions src/simd_utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use core::array;

use p3_field::PackedValue;

use crate::{PackedF, array::FieldArray};
use crate::{F, PackedF, array::FieldArray};

/// Packs scalar arrays into SIMD-friendly vertical layout.
///
Expand Down Expand Up @@ -55,14 +55,67 @@ pub fn pack_array<const N: usize>(data: &[FieldArray<N>]) -> [PackedF; N] {
/// ```
#[inline]
pub fn unpack_array<const N: usize>(packed_data: &[PackedF; N], output: &mut [FieldArray<N>]) {
for (i, data) in packed_data.iter().enumerate().take(N) {
let unpacked_v = data.as_slice();
for j in 0..PackedF::WIDTH {
output[j][i] = unpacked_v[j];
// Optimized for cache locality: iterate over output lanes first
#[allow(clippy::needless_range_loop)]
for j in 0..PackedF::WIDTH {
for i in 0..N {
output[j].0[i] = packed_data[i].as_slice()[j];
}
}
}

/// Pack even-indexed FieldArrays (stride 2) directly into destination.
///
/// Packs `data[0], data[2], data[4], ...` into `dest[offset..offset+N]`.
/// Useful for packing left children from interleaved `[L0, R0, L1, R1, ...]` pairs.
///
/// # Arguments
/// * `dest` - Destination slice to pack into
/// * `offset` - Starting index in `dest`
/// * `data` - Source slice of interleaved pairs (must have length >= 2 * WIDTH)
#[inline]
pub fn pack_even_into<const N: usize>(dest: &mut [PackedF], offset: usize, data: &[FieldArray<N>]) {
for i in 0..N {
dest[offset + i] = PackedF::from_fn(|lane| data[2 * lane][i]);
}
}

/// Pack odd-indexed FieldArrays (stride 2) directly into destination.
///
/// Packs `data[1], data[3], data[5], ...` into `dest[offset..offset+N]`.
/// Useful for packing right children from interleaved `[L0, R0, L1, R1, ...]` pairs.
///
/// # Arguments
/// * `dest` - Destination slice to pack into
/// * `offset` - Starting index in `dest`
/// * `data` - Source slice of interleaved pairs (must have length >= 2 * WIDTH)
#[inline]
pub fn pack_odd_into<const N: usize>(dest: &mut [PackedF], offset: usize, data: &[FieldArray<N>]) {
for i in 0..N {
dest[offset + i] = PackedF::from_fn(|lane| data[2 * lane + 1][i]);
}
}

/// Pack values generated by a function directly into destination.
///
/// For each element index `i` in `0..N`, generates a PackedF by calling
/// `f(i, lane)` for each SIMD lane.
///
/// # Arguments
/// * `dest` - Destination slice to pack into
/// * `offset` - Starting index in `dest`
/// * `f` - Function that takes (element_index, lane_index) and returns a field element
#[inline]
pub fn pack_fn_into<const N: usize>(
dest: &mut [PackedF],
offset: usize,
f: impl Fn(usize, usize) -> F,
) {
for i in 0..N {
dest[offset + i] = PackedF::from_fn(|lane| f(i, lane));
}
}

#[cfg(test)]
mod tests {
use crate::F;
Expand Down Expand Up @@ -176,5 +229,67 @@ mod tests {
// Verify roundtrip
prop_assert_eq!(original, unpacked);
}

#[test]
fn proptest_pack_even_odd_into(
_seed in any::<u64>()
) {
let mut rng = rand::rng();

// Generate interleaved pairs: [L0, R0, L1, R1, ...]
let pairs: [FieldArray<5>; 2 * PackedF::WIDTH] = array::from_fn(|_| {
FieldArray(array::from_fn(|_| rng.random()))
});

// Pack even (left children) and odd (right children)
let mut dest = [PackedF::ZERO; 12];
pack_even_into(&mut dest, 1, &pairs);
pack_odd_into(&mut dest, 6, &pairs);

// Verify even indices were packed correctly
for i in 0..5 {
for lane in 0..PackedF::WIDTH {
prop_assert_eq!(
dest[1 + i].as_slice()[lane],
pairs[2 * lane][i],
"Even packing mismatch at element {}, lane {}", i, lane
);
}
}

// Verify odd indices were packed correctly
for i in 0..5 {
for lane in 0..PackedF::WIDTH {
prop_assert_eq!(
dest[6 + i].as_slice()[lane],
pairs[2 * lane + 1][i],
"Odd packing mismatch at element {}, lane {}", i, lane
);
}
}
}

#[test]
fn proptest_pack_fn_into(
_seed in any::<u64>()
) {
// Pack using a function that generates predictable values
let mut dest = [PackedF::ZERO; 8];
pack_fn_into::<4>(&mut dest, 3, |elem_idx, lane_idx| {
F::from_u64((elem_idx * 100 + lane_idx) as u64)
});

// Verify
for i in 0..4 {
for lane in 0..PackedF::WIDTH {
let expected = F::from_u64((i * 100 + lane) as u64);
prop_assert_eq!(
dest[3 + i].as_slice()[lane],
expected,
"pack_fn_into mismatch at element {}, lane {}", i, lane
);
}
}
}
}
}
41 changes: 41 additions & 0 deletions src/symmetric/tweak_hash.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
use rand::Rng;

use rayon::prelude::*;

use crate::serialization::Serializable;
use crate::symmetric::prf::Pseudorandom;

Expand Down Expand Up @@ -46,6 +48,45 @@ pub trait TweakableHash {
message: &[Self::Domain],
) -> Self::Domain;

/// Computes one layer of a Merkle tree by hashing pairs of children into parents.
///
/// Consecutive pairs of child nodes produce their parent node by hashing
/// `(children[2*i], children[2*i+1])`. Each hash application uses a unique
/// tweak derived from the tree level and position.
///
/// # Arguments
/// * `parameter` - Public parameter for the hash function
/// * `level` - Tree level of the *parent* nodes being computed. NOTE: callers
/// need to pass `level + 1` where `level` is the children's level, since
/// tree levels are numbered from leaves (level 0) upward.
/// * `parent_start` - Starting index of the first parent in this layer, used
/// for computing position-dependent tweaks
/// * `children` - Slice of child nodes to hash pairwise (length must be even)
///
/// # Returns
/// A vector of parent nodes with length `children.len() / 2`.
///
/// This default implementation processes pairs in parallel using Rayon.
/// The Poseidon implementation overrides this with a SIMD-accelerated variant.
fn compute_tree_layer(
parameter: &Self::Parameter,
level: u8,
parent_start: usize,
children: &[Self::Domain],
) -> Vec<Self::Domain> {
// default implementation is scalar. tweak_tree/poseidon.rs provides a SIMD variant
children
.par_chunks_exact(2)
.enumerate()
.map(|(i, children)| {
// Parent index in this layer
let parent_pos = (parent_start + i) as u32;
// Hash children into their parent using the tweak
Self::apply(parameter, &Self::tree_tweak(level, parent_pos), children)
})
.collect()
}

/// Computes bottom tree leaves by walking hash chains for multiple epochs.
///
/// This method has a default scalar implementation that processes epochs in parallel.
Expand Down
Loading