Skip to content
Merged
Show file tree
Hide file tree
Changes from 11 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
125 changes: 120 additions & 5 deletions src/simd_utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use core::array;

use p3_field::PackedValue;

use crate::{PackedF, array::FieldArray};
use crate::{F, PackedF, array::FieldArray};

/// Packs scalar arrays into SIMD-friendly vertical layout.
///
Expand Down Expand Up @@ -55,14 +55,67 @@ pub fn pack_array<const N: usize>(data: &[FieldArray<N>]) -> [PackedF; N] {
/// ```
#[inline]
pub fn unpack_array<const N: usize>(packed_data: &[PackedF; N], output: &mut [FieldArray<N>]) {
for (i, data) in packed_data.iter().enumerate().take(N) {
let unpacked_v = data.as_slice();
for j in 0..PackedF::WIDTH {
output[j][i] = unpacked_v[j];
// Optimized for cache locality: iterate over output lanes first
#[allow(clippy::needless_range_loop)]
for j in 0..PackedF::WIDTH {
for i in 0..N {
output[j].0[i] = packed_data[i].as_slice()[j];
}
}
}

/// Packs the even-indexed `FieldArray`s of `data` (stride 2) into `dest`.
///
/// For every element index `i` in `0..N`, writes one packed value to
/// `dest[offset + i]` whose lane `k` holds `data[2 * k][i]`. With interleaved
/// `[L0, R0, L1, R1, ...]` input this gathers the left children.
///
/// # Arguments
/// * `dest` - Destination slice; entries `offset..offset + N` are overwritten
/// * `offset` - Index in `dest` of the first packed value to write
/// * `data` - Interleaved source pairs (must have length >= 2 * WIDTH)
///
/// # Panics
/// Panics on an out-of-bounds index, i.e. when `dest` is shorter than
/// `offset + N` or `data` is too short for the lanes being read.
#[inline]
pub fn pack_even_into<const N: usize>(dest: &mut [PackedF], offset: usize, data: &[FieldArray<N>]) {
    (0..N).for_each(|elem| {
        // Lane `k` reads from the k-th even slot of `data`.
        let column = PackedF::from_fn(|k| data[2 * k][elem]);
        dest[offset + elem] = column;
    });
}

/// Packs the odd-indexed `FieldArray`s of `data` (stride 2) into `dest`.
///
/// For every element index `i` in `0..N`, writes one packed value to
/// `dest[offset + i]` whose lane `k` holds `data[2 * k + 1][i]`. With
/// interleaved `[L0, R0, L1, R1, ...]` input this gathers the right children.
///
/// # Arguments
/// * `dest` - Destination slice; entries `offset..offset + N` are overwritten
/// * `offset` - Index in `dest` of the first packed value to write
/// * `data` - Interleaved source pairs (must have length >= 2 * WIDTH)
///
/// # Panics
/// Panics on an out-of-bounds index, i.e. when `dest` is shorter than
/// `offset + N` or `data` is too short for the lanes being read.
#[inline]
pub fn pack_odd_into<const N: usize>(dest: &mut [PackedF], offset: usize, data: &[FieldArray<N>]) {
    (0..N).for_each(|elem| {
        // Lane `k` reads from the k-th odd slot of `data`.
        let column = PackedF::from_fn(|k| data[2 * k + 1][elem]);
        dest[offset + elem] = column;
    });
}

/// Packs values produced by a callback directly into `dest`.
///
/// For every element index `i` in `0..N`, builds one packed value by calling
/// `f(i, lane)` for each SIMD lane and stores it at `dest[offset + i]`.
///
/// # Arguments
/// * `dest` - Destination slice; entries `offset..offset + N` are overwritten
/// * `offset` - Index in `dest` of the first packed value to write
/// * `f` - Callback taking `(element_index, lane_index)` and returning a field element
///
/// # Panics
/// Panics on an out-of-bounds index when `dest` is shorter than `offset + N`.
#[inline]
pub fn pack_fn_into<const N: usize>(
    dest: &mut [PackedF],
    offset: usize,
    f: impl Fn(usize, usize) -> F,
) {
    for elem in 0..N {
        // Build the packed value first, then store it in its slot.
        let packed = PackedF::from_fn(|lane| f(elem, lane));
        dest[offset + elem] = packed;
    }
}

#[cfg(test)]
mod tests {
use crate::F;
Expand Down Expand Up @@ -176,5 +229,67 @@ mod tests {
// Verify roundtrip
prop_assert_eq!(original, unpacked);
}

#[test]
fn proptest_pack_even_odd_into(
    seed in any::<u64>()
) {
    // Derive all randomness from the proptest-generated seed so that a
    // failing case can be replayed and shrunk; an unseeded thread-local RNG
    // would produce different data on every run.
    let mut rng = <rand::rngs::StdRng as rand::SeedableRng>::seed_from_u64(seed);

    // Generate interleaved pairs: [L0, R0, L1, R1, ...]
    let pairs: [FieldArray<5>; 2 * PackedF::WIDTH] = array::from_fn(|_| {
        FieldArray(array::from_fn(|_| rng.random()))
    });

    // Pack even (left children) into dest[1..6] and odd (right children)
    // into dest[6..11]; the two destination ranges do not overlap.
    let mut dest = [PackedF::ZERO; 12];
    pack_even_into(&mut dest, 1, &pairs);
    pack_odd_into(&mut dest, 6, &pairs);

    // Verify even indices were packed correctly
    for i in 0..5 {
        for lane in 0..PackedF::WIDTH {
            prop_assert_eq!(
                dest[1 + i].as_slice()[lane],
                pairs[2 * lane][i],
                "Even packing mismatch at element {}, lane {}", i, lane
            );
        }
    }

    // Verify odd indices were packed correctly
    for i in 0..5 {
        for lane in 0..PackedF::WIDTH {
            prop_assert_eq!(
                dest[6 + i].as_slice()[lane],
                pairs[2 * lane + 1][i],
                "Odd packing mismatch at element {}, lane {}", i, lane
            );
        }
    }
}

#[test]
fn proptest_pack_fn_into(
    seed in any::<u64>()
) {
    // Fold the generated seed into the values so that each proptest case
    // exercises different data. The base is kept small so that every
    // generated integer stays far below the field modulus.
    let base = seed % 1_000_000;
    let value_at = |elem_idx: usize, lane_idx: usize| {
        F::from_u64(base + (elem_idx * 100 + lane_idx) as u64)
    };

    // Pack four elements starting at dest[3]
    let mut dest = [PackedF::ZERO; 8];
    pack_fn_into::<4>(&mut dest, 3, &value_at);

    // Verify every lane of every packed element against the generator
    for i in 0..4 {
        for lane in 0..PackedF::WIDTH {
            let expected = value_at(i, lane);
            prop_assert_eq!(
                dest[3 + i].as_slice()[lane],
                expected,
                "pack_fn_into mismatch at element {}, lane {}", i, lane
            );
        }
    }
}
}
}
26 changes: 26 additions & 0 deletions src/symmetric/tweak_hash.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
use rand::Rng;

use rayon::prelude::*;

use crate::serialization::Serializable;
use crate::symmetric::prf::Pseudorandom;

Expand Down Expand Up @@ -46,6 +48,30 @@ pub trait TweakableHash {
message: &[Self::Domain],
) -> Self::Domain;

/// Applies the calculation for a single tweak hash tree layer.
///
/// Consecutive pairs of `children` are hashed into one parent each; the
/// parent built from the `i`-th pair is tweaked with position
/// `parent_start + i`.
///
/// # Arguments
/// * `parameter` - Public parameter forwarded to [`Self::apply`]
/// * `level` - Level used for the parents' tree tweak. Callers already pass
///   the parent level (child level + 1), so it must not be incremented here.
/// * `parent_start` - Position of the first parent within its layer
/// * `children` - Child nodes, laid out as consecutive sibling pairs
///
/// # Returns
/// One parent per complete pair of children, in order. A trailing unpaired
/// child is not hashed (`par_chunks_exact` skips the remainder).
fn compute_tree_layer(
    parameter: &Self::Parameter,
    level: u8,
    parent_start: usize,
    children: &[Self::Domain],
) -> Vec<Self::Domain> {
    // default implementation is scalar. tweak_tree/poseidon.rs provides a SIMD variant
    children
        .par_chunks_exact(2)
        .enumerate()
        .map(|(i, pair)| {
            // Parent index in this layer
            let parent_pos = (parent_start + i) as u32;
            // Hash the sibling pair into its parent. `level` is used as-is:
            // adding 1 here would apply the increment twice, because the
            // caller has already passed the parent level.
            Self::apply(parameter, &Self::tree_tweak(level, parent_pos), pair)
        })
        .collect()
}

/// Computes bottom tree leaves by walking hash chains for multiple epochs.
///
/// This method has a default scalar implementation that processes epochs in parallel.
Expand Down
Loading