8 changes: 4 additions & 4 deletions Cargo.toml
@@ -39,10 +39,10 @@ dashmap = "6.1.0"
serde = { version = "1.0", features = ["derive", "alloc"] }
thiserror = "2.0"

-p3-field = { git = "https://github.com/Plonky3/Plonky3.git", rev = "2117e4b" }
-p3-baby-bear = { git = "https://github.com/Plonky3/Plonky3.git", rev = "2117e4b" }
-p3-koala-bear = { git = "https://github.com/Plonky3/Plonky3.git", rev = "2117e4b" }
-p3-symmetric = { git = "https://github.com/Plonky3/Plonky3.git", rev = "2117e4b" }
+p3-field = { git = "https://github.com/Plonky3/Plonky3.git", rev = "a33a312" }
+p3-baby-bear = { git = "https://github.com/Plonky3/Plonky3.git", rev = "a33a312" }
+p3-koala-bear = { git = "https://github.com/Plonky3/Plonky3.git", rev = "a33a312" }
+p3-symmetric = { git = "https://github.com/Plonky3/Plonky3.git", rev = "a33a312" }

[dev-dependencies]
criterion = "0.7"
2 changes: 2 additions & 0 deletions src/lib.rs
@@ -1,3 +1,4 @@
+use p3_field::Field;
use p3_koala_bear::{
    KoalaBear, Poseidon2KoalaBear, default_koalabear_poseidon2_16, default_koalabear_poseidon2_24,
};
@@ -11,6 +12,7 @@ pub const TWEAK_SEPARATOR_FOR_TREE_HASH: u8 = 0x01;
pub const TWEAK_SEPARATOR_FOR_CHAIN_HASH: u8 = 0x00;

type F = KoalaBear;
+pub(crate) type PackedF = <F as Field>::Packing;

pub(crate) mod hypercube;
pub(crate) mod inc_encoding;
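`Field::Packing` is Plonky3's SIMD-packed representation of a field: one `PackedF` value holds `PackedF::WIDTH` KoalaBear lanes and applies operations lane-wise. A minimal sketch of the API surface this alias exposes, using only the `PackedValue` methods exercised by `simd_utils` later in this diff (the printed lane count is target-dependent and therefore illustrative):

```rust
use p3_field::{Field, PackedValue, PrimeCharacteristicRing};
use p3_koala_bear::KoalaBear;

type F = KoalaBear;
type PackedF = <F as Field>::Packing;

fn main() {
    // The lane count depends on the compiled target's SIMD features
    // (e.g. 8 lanes with AVX2); it degrades gracefully on targets without SIMD.
    println!("lanes per packed value: {}", PackedF::WIDTH);

    // from_fn fills lane i with the closure's value; as_slice views the lanes.
    let p = PackedF::from_fn(|i| F::from_u64(i as u64));
    assert_eq!(p.as_slice()[0], F::ZERO);
}
```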
37 changes: 7 additions & 30 deletions src/signature/generalized_xmss.rs
@@ -205,38 +205,15 @@ where
        let chain_length = IE::BASE;

        // the range of epochs covered by that bottom tree
-       let epoch_range_start = bottom_tree_index * leafs_per_bottom_tree;
-       let epoch_range_end = epoch_range_start + leafs_per_bottom_tree;
-       let epoch_range = epoch_range_start..epoch_range_end;
-
-       // parallelize the chain ends hash computation for each epoch in the interval for that bottom tree
-       let chain_ends_hashes = epoch_range
-           .into_par_iter()
-           .map(|epoch| {
-               // each epoch has a number of chains
-               // parallelize the chain ends computation for each chain
-               let chain_ends = (0..num_chains)
-                   .into_par_iter()
-                   .map(|chain_index| {
-                       // each chain start is just a PRF evaluation
-                       let start =
-                           PRF::get_domain_element(prf_key, epoch as u32, chain_index as u64).into();
-                       // walk the chain to get the public chain end
-                       chain::<TH>(
-                           parameter,
-                           epoch as u32,
-                           chain_index as u8,
-                           0,
-                           chain_length - 1,
-                           &start,
-                       )
-                   })
-                   .collect::<Vec<_>>();
-               // build hash of chain ends / public keys
-               TH::apply(parameter, &TH::tree_tweak(0, epoch as u32), &chain_ends)
-           })
+       let epoch_start = bottom_tree_index * leafs_per_bottom_tree;
+       let epochs: Vec<u32> = (epoch_start..epoch_start + leafs_per_bottom_tree)
+           .map(|e| e as u32)
+           .collect();

+       // Compute chain ends for all epochs.
+       let chain_ends_hashes =
+           TH::compute_tree_leaves::<PRF>(prf_key, parameter, &epochs, num_chains, chain_length);

        // now that we have the hashes of all chain ends (= leafs of our tree), we can compute the bottom tree
        HashSubTree::new_bottom_tree(
            LOG_LIFETIME,
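For reference, a scalar version of `TH::compute_tree_leaves` can be reconstructed from the code removed above. This is an illustrative sketch, not the crate's actual default implementation; it assumes rayon's parallel iterators (the removed code already used `into_par_iter`) and the `chain`, `Pseudorandom`, and `TweakableHash` items visible elsewhere in this diff:

```rust
use rayon::prelude::*;

/// Illustrative only: one leaf per epoch, each the hash of that epoch's
/// chain ends, mirroring the removed scalar code above.
fn compute_tree_leaves_scalar<TH, PRF>(
    prf_key: &PRF::Key,
    parameter: &TH::Parameter,
    epochs: &[u32],
    num_chains: usize,
    chain_length: usize,
) -> Vec<TH::Domain>
where
    TH: TweakableHash,
    PRF: Pseudorandom,
    PRF::Domain: Into<TH::Domain>,
{
    epochs
        .par_iter()
        .map(|&epoch| {
            // Each chain start is a PRF evaluation; walk it to the public end.
            let chain_ends: Vec<TH::Domain> = (0..num_chains)
                .map(|chain_index| {
                    let start =
                        PRF::get_domain_element(prf_key, epoch, chain_index as u64).into();
                    chain::<TH>(parameter, epoch, chain_index as u8, 0, chain_length - 1, &start)
                })
                .collect();
            // The leaf for this epoch is the hash of its chain ends.
            TH::apply(parameter, &TH::tree_tweak(0, epoch), &chain_ends)
        })
        .collect()
}
```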
1 change: 1 addition & 0 deletions src/symmetric.rs
@@ -1,4 +1,5 @@
pub mod message_hash;
pub mod prf;
+pub mod simd_utils;
pub mod tweak_hash;
pub mod tweak_hash_tree;
2 changes: 1 addition & 1 deletion src/symmetric/message_hash/poseidon.rs
@@ -164,7 +164,7 @@ where
            .copied()
            .collect();

-       let hash_fe = poseidon_compress::<_, 24, HASH_LEN_FE>(&perm, &combined_input_vec);
+       let hash_fe = poseidon_compress::<F, _, 24, HASH_LEN_FE>(&perm, &combined_input_vec);

        // decode field elements into chunks and return them
        decode_to_chunks::<DIMENSION, BASE, HASH_LEN_FE>(&hash_fe).to_vec()
2 changes: 1 addition & 1 deletion src/symmetric/message_hash/top_level_poseidon.rs
@@ -159,7 +159,7 @@ where
            .collect();

        let iteration_pos_output =
-           poseidon_compress::<_, 24, POS_OUTPUT_LEN_PER_INV_FE>(&perm, &combined_input);
+           poseidon_compress::<F, _, 24, POS_OUTPUT_LEN_PER_INV_FE>(&perm, &combined_input);

        pos_outputs[i * POS_OUTPUT_LEN_PER_INV_FE..(i + 1) * POS_OUTPUT_LEN_PER_INV_FE]
            .copy_from_slice(&iteration_pos_output);
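Both call sites above now name the field type explicitly instead of leaving it to inference, which suggests `poseidon_compress` became generic over the element type so the same helper can also operate on packed values. Its actual definition is not part of this diff; a hypothetical shape consistent with these call sites:

```rust
use p3_symmetric::Permutation;

// Hypothetical sketch only; the real signature lives elsewhere in this crate.
// `T` is the element type (scalar `F` at the call sites above, `PackedF` for
// SIMD callers) and `P` is the Poseidon2 permutation, inferred as `_`.
fn poseidon_compress<T, P, const WIDTH: usize, const OUT: usize>(
    perm: &P,
    input: &[T],
) -> [T; OUT]
where
    T: Copy + Default,
    P: Permutation<[T; WIDTH]>,
{
    unimplemented!("illustrative signature; see the crate's Poseidon utilities")
}
```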
178 changes: 178 additions & 0 deletions src/symmetric/simd_utils.rs
@@ -0,0 +1,178 @@
use core::array;

use p3_field::PackedValue;

use crate::{F, PackedF};

/// Packs scalar arrays into SIMD-friendly vertical layout.
///
/// Transposes from horizontal layout `[[F; N]; WIDTH]` to vertical layout `[PackedF; N]`.
///
/// Input layout (horizontal): each row is one complete array
/// ```text
/// data[0] = [a0, a1, a2, ..., aN]
/// data[1] = [b0, b1, b2, ..., bN]
/// data[2] = [c0, c1, c2, ..., cN]
/// ...
/// ```
///
/// Output layout (vertical): each PackedF holds one element from each array
/// ```text
/// result[0] = PackedF([a0, b0, c0, ...]) // All first elements
/// result[1] = PackedF([a1, b1, c1, ...]) // All second elements
/// result[2] = PackedF([a2, b2, c2, ...]) // All third elements
/// ...
/// ```
///
/// This vertical packing enables efficient SIMD operations where a single instruction
/// processes the same element position across multiple arrays simultaneously.
#[inline]
pub fn pack_array<const N: usize>(data: &[[F; N]]) -> [PackedF; N] {
    array::from_fn(|i| PackedF::from_fn(|j| data[j][i]))
}

/// Unpacks SIMD vertical layout back into scalar arrays.
///
/// Transposes from vertical layout `[PackedF; N]` to horizontal layout `[[F; N]; WIDTH]`.
///
/// This is the inverse operation of `pack_array`. The output buffer must be preallocated
/// with size `[WIDTH][N]` where `WIDTH = PackedF::WIDTH`.
///
/// Input layout (vertical): each PackedF holds one element from each array
/// ```text
/// packed_data[0] = PackedF([a0, b0, c0, ...])
/// packed_data[1] = PackedF([a1, b1, c1, ...])
/// packed_data[2] = PackedF([a2, b2, c2, ...])
/// ...
/// ```
///
/// Output layout (horizontal): each row is one complete array
/// ```text
/// output[0] = [a0, a1, a2, ..., aN]
/// output[1] = [b0, b1, b2, ..., bN]
/// output[2] = [c0, c1, c2, ..., cN]
/// ...
/// ```
#[inline]
pub fn unpack_array<const N: usize>(packed_data: &[PackedF; N], output: &mut [[F; N]]) {
    for (i, data) in packed_data.iter().enumerate().take(N) {
        let unpacked_v = data.as_slice();
        for j in 0..PackedF::WIDTH {
            output[j][i] = unpacked_v[j];
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use p3_field::PrimeCharacteristicRing;
    use proptest::prelude::*;
    use rand::Rng;

    #[test]
    fn test_pack_array_simple() {
        // Test with N=2 (2 field elements per array)
        // Create WIDTH arrays of [F; 2]
        let data: [[F; 2]; PackedF::WIDTH] =
            array::from_fn(|i| [F::from_u64(i as u64), F::from_u64((i + 100) as u64)]);

        let packed = pack_array(&data);

        // Check that packed[0] contains all first elements
        for (lane, &expected) in data.iter().enumerate() {
            assert_eq!(packed[0].as_slice()[lane], expected[0]);
        }

        // Check that packed[1] contains all second elements
        for (lane, &expected) in data.iter().enumerate() {
            assert_eq!(packed[1].as_slice()[lane], expected[1]);
        }
    }

    #[test]
    fn test_unpack_array_simple() {
        // Create packed data
        let packed: [PackedF; 2] = [
            PackedF::from_fn(|i| F::from_u64(i as u64)),
            PackedF::from_fn(|i| F::from_u64((i + 100) as u64)),
        ];

        // Unpack
        let mut output = [[F::ZERO; 2]; PackedF::WIDTH];
        unpack_array(&packed, &mut output);

        // Verify
        for (lane, arr) in output.iter().enumerate() {
            assert_eq!(arr[0], F::from_u64(lane as u64));
            assert_eq!(arr[1], F::from_u64((lane + 100) as u64));
        }
    }

    #[test]
    fn test_pack_preserves_element_order() {
        // Create data where each array has sequential values
        let data: [[F; 3]; PackedF::WIDTH] = array::from_fn(|i| {
            [
                F::from_u64((i * 3) as u64),
                F::from_u64((i * 3 + 1) as u64),
                F::from_u64((i * 3 + 2) as u64),
            ]
        });

        let packed = pack_array(&data);

        // Verify the packing structure
        // packed[0] should contain: [0, 3, 6, 9, ...]
        // packed[1] should contain: [1, 4, 7, 10, ...]
        // packed[2] should contain: [2, 5, 8, 11, ...]
        for (element_idx, p) in packed.iter().enumerate() {
            for lane in 0..PackedF::WIDTH {
                let expected = F::from_u64((lane * 3 + element_idx) as u64);
                assert_eq!(p.as_slice()[lane], expected);
            }
        }
    }

    #[test]
    fn test_unpack_preserves_element_order() {
        // Create packed data with known pattern
        let packed: [PackedF; 3] = [
            PackedF::from_fn(|i| F::from_u64((i * 3) as u64)),
            PackedF::from_fn(|i| F::from_u64((i * 3 + 1) as u64)),
            PackedF::from_fn(|i| F::from_u64((i * 3 + 2) as u64)),
        ];

        let mut output = [[F::ZERO; 3]; PackedF::WIDTH];
        unpack_array(&packed, &mut output);

        // Verify each array has sequential values
        for (lane, arr) in output.iter().enumerate() {
            assert_eq!(arr[0], F::from_u64((lane * 3) as u64));
            assert_eq!(arr[1], F::from_u64((lane * 3 + 1) as u64));
            assert_eq!(arr[2], F::from_u64((lane * 3 + 2) as u64));
        }
    }

    proptest! {
        #[test]
        fn proptest_pack_unpack_roundtrip(
            _seed in any::<u64>()
        ) {
            let mut rng = rand::rng();

            // Generate random data with N=10
            let original: [[F; 10]; PackedF::WIDTH] = array::from_fn(|_| {
                array::from_fn(|_| rng.random())
            });

            // Pack and unpack
            let packed = pack_array(&original);
            let mut unpacked = [[F::ZERO; 10]; PackedF::WIDTH];
            unpack_array(&packed, &mut unpacked);

            // Verify roundtrip
            prop_assert_eq!(original, unpacked);
        }
    }
}
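To illustrate how these helpers combine, here is a sketch of the intended usage pattern: pack, perform one packed operation per element position, then unpack. The function name is hypothetical, and the sketch assumes `PackedF` supports lane-wise field arithmetic, as Plonky3's packed fields do:

```rust
use core::array;
use p3_field::PrimeCharacteristicRing;

fn packed_add_demo() {
    // WIDTH independent 4-element arrays, one per SIMD lane.
    let lhs: [[F; 4]; PackedF::WIDTH] =
        array::from_fn(|lane| array::from_fn(|k| F::from_u64((lane + k) as u64)));
    let rhs = lhs;

    let a = pack_array(&lhs);
    let b = pack_array(&rhs);
    // One packed addition per element position covers all WIDTH lanes at once.
    let sum: [PackedF; 4] = array::from_fn(|k| a[k] + b[k]);

    let mut out = [[F::ZERO; 4]; PackedF::WIDTH];
    unpack_array(&sum, &mut out);
    assert_eq!(out[1][2], F::from_u64(6)); // lane 1, element 2: (1 + 2) doubled
}
```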
28 changes: 25 additions & 3 deletions src/symmetric/tweak_hash.rs
@@ -1,6 +1,8 @@
use rand::Rng;
use serde::{Serialize, de::DeserializeOwned};

+use crate::symmetric::prf::Pseudorandom;
+
/// Trait to model a tweakable hash function.
/// Such a function takes a public parameter, a tweak, and a
/// message to be hashed. The tweak should be understood as an
@@ -14,8 +16,13 @@ use serde::{Serialize, de::DeserializeOwned};
/// to obtain distinct tweaks for applications in chains and
/// applications in Merkle trees.
pub trait TweakableHash {
+    /// Public parameter type for the hash function
    type Parameter: Copy + Sized + Send + Sync + Serialize + DeserializeOwned;
+
+    /// Tweak type for domain separation
    type Tweak;
+
+    /// Domain element type (defines output and input types to the hash)
    type Domain: Copy + PartialEq + Sized + Send + Sync + Serialize + DeserializeOwned;

    /// Generates a random public parameter.
@@ -39,8 +46,24 @@ pub trait TweakableHash {
        message: &[Self::Domain],
    ) -> Self::Domain;

-    /// Function to check internal consistency of any given parameters
-    /// For testing only, and expected to panic if something is wrong.
+    /// Computes bottom tree leaves by walking hash chains for multiple epochs.
+    ///
+    /// This method has a default scalar implementation that processes epochs in parallel.
+    fn compute_tree_leaves<PRF>(
+        prf_key: &PRF::Key,
+        parameter: &Self::Parameter,
+        epochs: &[u32],
+        num_chains: usize,
+        chain_length: usize,
+    ) -> Vec<Self::Domain>
+    where
+        PRF: Pseudorandom,
+        PRF::Domain: Into<Self::Domain>,
+        Self: Sized;
+
+    /// Function to check internal consistency of any given parameters.
+    ///
+    /// This is for testing only and is expected to panic if something is wrong.
    #[cfg(test)]
    fn internal_consistency_check();
}
@@ -77,7 +100,6 @@ pub mod poseidon;

#[cfg(test)]
mod tests {
-
    use crate::symmetric::tweak_hash::poseidon::PoseidonTweak44;

    use super::*;
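Taken together, the pieces in this PR point at one pattern: an implementor of `compute_tree_leaves` can batch `PackedF::WIDTH` epochs, one per SIMD lane. A purely illustrative sketch of that pack/step/unpack loop, with the packed hash step left abstract since the packed Poseidon2 plumbing is not shown in this diff:

```rust
use crate::symmetric::simd_utils::{pack_array, unpack_array};
use crate::{F, PackedF};

/// Advance WIDTH chain states in lockstep. `packed_step` is a placeholder
/// for a packed tweakable-hash step, not a real API in this crate.
fn packed_chain_walk<const N: usize>(
    states: &mut [[F; N]],
    steps: usize,
    packed_step: impl Fn([PackedF; N]) -> [PackedF; N],
) {
    debug_assert_eq!(states.len(), PackedF::WIDTH);
    let mut packed = pack_array(states);
    for _ in 0..steps {
        // One packed step advances all WIDTH chains at once.
        packed = packed_step(packed);
    }
    unpack_array(&packed, states);
}
```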