8 changes: 4 additions & 4 deletions Cargo.toml
@@ -39,10 +39,10 @@ dashmap = "6.1.0"
serde = { version = "1.0", features = ["derive", "alloc"] }
thiserror = "2.0"

p3-field = { git = "https://github.com/Plonky3/Plonky3.git", rev = "2117e4b" }
p3-baby-bear = { git = "https://github.com/Plonky3/Plonky3.git", rev = "2117e4b" }
p3-koala-bear = { git = "https://github.com/Plonky3/Plonky3.git", rev = "2117e4b" }
p3-symmetric = { git = "https://github.com/Plonky3/Plonky3.git", rev = "2117e4b" }
p3-field = { git = "https://github.com/Plonky3/Plonky3.git", rev = "a33a312" }
p3-baby-bear = { git = "https://github.com/Plonky3/Plonky3.git", rev = "a33a312" }
p3-koala-bear = { git = "https://github.com/Plonky3/Plonky3.git", rev = "a33a312" }
p3-symmetric = { git = "https://github.com/Plonky3/Plonky3.git", rev = "a33a312" }

[dev-dependencies]
criterion = "0.7"
3 changes: 3 additions & 0 deletions src/lib.rs
@@ -1,3 +1,4 @@
use p3_field::Field;
use p3_koala_bear::{
KoalaBear, Poseidon2KoalaBear, default_koalabear_poseidon2_16, default_koalabear_poseidon2_24,
};
@@ -11,10 +12,12 @@ pub const TWEAK_SEPARATOR_FOR_TREE_HASH: u8 = 0x01;
pub const TWEAK_SEPARATOR_FOR_CHAIN_HASH: u8 = 0x00;

type F = KoalaBear;
pub(crate) type PackedF = <F as Field>::Packing;

pub(crate) mod hypercube;
pub(crate) mod inc_encoding;
pub mod signature;
pub(crate) mod simd_utils;
pub(crate) mod symmetric;

// Cached Poseidon2 permutations.
37 changes: 7 additions & 30 deletions src/signature/generalized_xmss.rs
@@ -205,38 +205,15 @@ where
let chain_length = IE::BASE;

// the range of epochs covered by that bottom tree
let epoch_range_start = bottom_tree_index * leafs_per_bottom_tree;
let epoch_range_end = epoch_range_start + leafs_per_bottom_tree;
let epoch_range = epoch_range_start..epoch_range_end;

// parallelize the chain ends hash computation for each epoch in the interval for that bottom tree
let chain_ends_hashes = epoch_range
.into_par_iter()
.map(|epoch| {
// each epoch has a number of chains
// parallelize the chain ends computation for each chain
let chain_ends = (0..num_chains)
.into_par_iter()
.map(|chain_index| {
// each chain start is just a PRF evaluation
let start =
PRF::get_domain_element(prf_key, epoch as u32, chain_index as u64).into();
// walk the chain to get the public chain end
chain::<TH>(
parameter,
epoch as u32,
chain_index as u8,
0,
chain_length - 1,
&start,
)
})
.collect::<Vec<_>>();
// build hash of chain ends / public keys
TH::apply(parameter, &TH::tree_tweak(0, epoch as u32), &chain_ends)
})
let epoch_start = bottom_tree_index * leafs_per_bottom_tree;
let epochs: Vec<u32> = (epoch_start..epoch_start + leafs_per_bottom_tree)
.map(|e| e as u32)
.collect();

// Compute chain ends for all epochs.
let chain_ends_hashes =
TH::compute_tree_leaves::<PRF>(prf_key, parameter, &epochs, num_chains, chain_length);

// now that we have the hashes of all chain ends (= leaves of our tree), we can compute the bottom tree
HashSubTree::new_bottom_tree(
LOG_LIFETIME,
178 changes: 178 additions & 0 deletions src/simd_utils.rs
@@ -0,0 +1,178 @@
use core::array;

use p3_field::PackedValue;

use crate::{F, PackedF};

/// Packs scalar arrays into SIMD-friendly vertical layout.
///
/// Transposes from horizontal layout `[[F; N]; WIDTH]` to vertical layout `[PackedF; N]`.
///
/// Input layout (horizontal): each row is one complete array
/// ```text
/// data[0] = [a0, a1, a2, ..., aN]
/// data[1] = [b0, b1, b2, ..., bN]
/// data[2] = [c0, c1, c2, ..., cN]
/// ...
/// ```
///
/// Output layout (vertical): each PackedF holds one element from each array
/// ```text
/// result[0] = PackedF([a0, b0, c0, ...]) // All first elements
/// result[1] = PackedF([a1, b1, c1, ...]) // All second elements
/// result[2] = PackedF([a2, b2, c2, ...]) // All third elements
/// ...
/// ```
///
/// This vertical packing enables efficient SIMD operations where a single instruction
/// processes the same element position across multiple arrays simultaneously.
#[inline]
pub fn pack_array<const N: usize>(data: &[[F; N]]) -> [PackedF; N] {
array::from_fn(|i| PackedF::from_fn(|j| data[j][i]))
}
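
To illustrate the payoff described above, a minimal sketch (not part of this diff, written as it might appear in a test body; it assumes the `F`/`PackedF` aliases from `src/lib.rs`): after packing, a single operation on a `PackedF` touches the same element position in all `WIDTH` arrays at once.

```rust
use core::array;
use p3_field::{PackedValue, PrimeCharacteristicRing};

// WIDTH scalar arrays of length 4, one per SIMD lane.
let rows: [[F; 4]; PackedF::WIDTH] =
    array::from_fn(|lane| array::from_fn(|i| F::from_u64((lane * 4 + i) as u64)));
let packed = pack_array(&rows);

// One packed multiply squares element 0 of every array simultaneously.
let squared = packed[0] * packed[0];
for lane in 0..PackedF::WIDTH {
    assert_eq!(squared.as_slice()[lane], rows[lane][0] * rows[lane][0]);
}
```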

/// Unpacks SIMD vertical layout back into scalar arrays.
///
/// Transposes from vertical layout `[PackedF; N]` to horizontal layout `[[F; N]; WIDTH]`.
///
/// This is the inverse operation of `pack_array`. The output buffer must be preallocated
/// with size `[WIDTH][N]` where `WIDTH = PackedF::WIDTH`.
///
/// Input layout (vertical): each PackedF holds one element from each array
/// ```text
/// packed_data[0] = PackedF([a0, b0, c0, ...])
/// packed_data[1] = PackedF([a1, b1, c1, ...])
/// packed_data[2] = PackedF([a2, b2, c2, ...])
/// ...
/// ```
///
/// Output layout (horizontal): each row is one complete array
/// ```text
/// output[0] = [a0, a1, a2, ..., aN]
/// output[1] = [b0, b1, b2, ..., bN]
/// output[2] = [c0, c1, c2, ..., cN]
/// ...
/// ```
#[inline]
pub fn unpack_array<const N: usize>(packed_data: &[PackedF; N], output: &mut [[F; N]]) {
for (i, data) in packed_data.iter().enumerate().take(N) {
let unpacked_v = data.as_slice();
for j in 0..PackedF::WIDTH {
output[j][i] = unpacked_v[j];
}
}
}
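
And the inverse direction, sketched under the same assumptions: pack, apply a packed operation, then unpack, which leaves the per-array results back in the horizontal layout.

```rust
use core::array;
use p3_field::{PackedValue, PrimeCharacteristicRing};

let rows: [[F; 4]; PackedF::WIDTH] =
    array::from_fn(|lane| array::from_fn(|i| F::from_u64((lane * 4 + i) as u64)));

// Pack, double every element with N packed additions, then unpack.
let mut packed = pack_array(&rows);
for p in packed.iter_mut() {
    *p = *p + *p;
}
let mut out = [[F::ZERO; 4]; PackedF::WIDTH];
unpack_array(&packed, &mut out);

// The horizontal layout is restored, with the operation applied elementwise.
for lane in 0..PackedF::WIDTH {
    for i in 0..4 {
        assert_eq!(out[lane][i], rows[lane][i] + rows[lane][i]);
    }
}
```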

#[cfg(test)]
mod tests {
use super::*;
use p3_field::PrimeCharacteristicRing;
use proptest::prelude::*;
use rand::Rng;

#[test]
fn test_pack_array_simple() {
// Test with N=2 (2 field elements per array)
// Create WIDTH arrays of [F; 2]
let data: [[F; 2]; PackedF::WIDTH] =
array::from_fn(|i| [F::from_u64(i as u64), F::from_u64((i + 100) as u64)]);

let packed = pack_array(&data);

// Check that packed[0] contains all first elements
for (lane, &expected) in data.iter().enumerate() {
assert_eq!(packed[0].as_slice()[lane], expected[0]);
}

// Check that packed[1] contains all second elements
for (lane, &expected) in data.iter().enumerate() {
assert_eq!(packed[1].as_slice()[lane], expected[1]);
}
}

#[test]
fn test_unpack_array_simple() {
// Create packed data
let packed: [PackedF; 2] = [
PackedF::from_fn(|i| F::from_u64(i as u64)),
PackedF::from_fn(|i| F::from_u64((i + 100) as u64)),
];

// Unpack
let mut output = [[F::ZERO; 2]; PackedF::WIDTH];
unpack_array(&packed, &mut output);

// Verify
for (lane, arr) in output.iter().enumerate() {
assert_eq!(arr[0], F::from_u64(lane as u64));
assert_eq!(arr[1], F::from_u64((lane + 100) as u64));
}
}

#[test]
fn test_pack_preserves_element_order() {
// Create data where each array has sequential values
let data: [[F; 3]; PackedF::WIDTH] = array::from_fn(|i| {
[
F::from_u64((i * 3) as u64),
F::from_u64((i * 3 + 1) as u64),
F::from_u64((i * 3 + 2) as u64),
]
});

let packed = pack_array(&data);

// Verify the packing structure
// packed[0] should contain: [0, 3, 6, 9, ...]
// packed[1] should contain: [1, 4, 7, 10, ...]
// packed[2] should contain: [2, 5, 8, 11, ...]
for (element_idx, p) in packed.iter().enumerate() {
for lane in 0..PackedF::WIDTH {
let expected = F::from_u64((lane * 3 + element_idx) as u64);
assert_eq!(p.as_slice()[lane], expected);
}
}
}

#[test]
fn test_unpack_preserves_element_order() {
// Create packed data with known pattern
let packed: [PackedF; 3] = [
PackedF::from_fn(|i| F::from_u64((i * 3) as u64)),
PackedF::from_fn(|i| F::from_u64((i * 3 + 1) as u64)),
PackedF::from_fn(|i| F::from_u64((i * 3 + 2) as u64)),
];

let mut output = [[F::ZERO; 3]; PackedF::WIDTH];
unpack_array(&packed, &mut output);

// Verify each array has sequential values
for (lane, arr) in output.iter().enumerate() {
assert_eq!(arr[0], F::from_u64((lane * 3) as u64));
assert_eq!(arr[1], F::from_u64((lane * 3 + 1) as u64));
assert_eq!(arr[2], F::from_u64((lane * 3 + 2) as u64));
}
}

proptest! {
#[test]
fn proptest_pack_unpack_roundtrip(
_seed in any::<u64>()
) {
let mut rng = rand::rng();

// Generate random data with N=10
let original: [[F; 10]; PackedF::WIDTH] = array::from_fn(|_| {
array::from_fn(|_| rng.random())
});

// Pack and unpack
let packed = pack_array(&original);
let mut unpacked = [[F::ZERO; 10]; PackedF::WIDTH];
unpack_array(&packed, &mut unpacked);

// Verify roundtrip
prop_assert_eq!(original, unpacked);
}
}
}
2 changes: 1 addition & 1 deletion src/symmetric/message_hash/poseidon.rs
@@ -164,7 +164,7 @@ where
.copied()
.collect();

let hash_fe = poseidon_compress::<_, 24, HASH_LEN_FE>(&perm, &combined_input_vec);
let hash_fe = poseidon_compress::<F, _, 24, HASH_LEN_FE>(&perm, &combined_input_vec);

// decode field elements into chunks and return them
decode_to_chunks::<DIMENSION, BASE, HASH_LEN_FE>(&hash_fe).to_vec()
2 changes: 1 addition & 1 deletion src/symmetric/message_hash/top_level_poseidon.rs
@@ -159,7 +159,7 @@ where
.collect();

let iteration_pos_output =
poseidon_compress::<_, 24, POS_OUTPUT_LEN_PER_INV_FE>(&perm, &combined_input);
poseidon_compress::<F, _, 24, POS_OUTPUT_LEN_PER_INV_FE>(&perm, &combined_input);

pos_outputs[i * POS_OUTPUT_LEN_PER_INV_FE..(i + 1) * POS_OUTPUT_LEN_PER_INV_FE]
.copy_from_slice(&iteration_pos_output);
28 changes: 25 additions & 3 deletions src/symmetric/tweak_hash.rs
@@ -1,6 +1,8 @@
use rand::Rng;
use serde::{Serialize, de::DeserializeOwned};

use crate::symmetric::prf::Pseudorandom;

/// Trait to model a tweakable hash function.
/// Such a function takes a public parameter, a tweak, and a
/// message to be hashed. The tweak should be understood as an
@@ -14,8 +16,13 @@ use serde::{Serialize, de::DeserializeOwned};
/// to obtain distinct tweaks for applications in chains and
/// applications in Merkle trees.
pub trait TweakableHash {
/// Public parameter type for the hash function
type Parameter: Copy + Sized + Send + Sync + Serialize + DeserializeOwned;

/// Tweak type for domain separation
type Tweak;

/// Domain element type (defines output and input types to the hash)
type Domain: Copy + PartialEq + Sized + Send + Sync + Serialize + DeserializeOwned;

/// Generates a random public parameter.
@@ -39,8 +46,24 @@ pub trait TweakableHash {
message: &[Self::Domain],
) -> Self::Domain;

/// Function to check internal consistency of any given parameters
/// For testing only, and expected to panic if something is wrong.
/// Computes bottom tree leaves by walking hash chains for multiple epochs.
///
/// This method has a default scalar implementation that processes epochs in parallel.
fn compute_tree_leaves<PRF>(
prf_key: &PRF::Key,
parameter: &Self::Parameter,
epochs: &[u32],
num_chains: usize,
chain_length: usize,
) -> Vec<Self::Domain>
where
PRF: Pseudorandom,
PRF::Domain: Into<Self::Domain>,
Self: Sized;

/// Function to check internal consistency of any given parameters.
///
/// This is for testing only and is expected to panic if something is wrong.
#[cfg(test)]
fn internal_consistency_check();
}
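
For intuition, here is a plausible scalar body for `compute_tree_leaves`, reconstructed from the per-epoch loop this PR removes from `generalized_xmss.rs`. The actual default implementation is not shown in this diff; the `rayon` prelude and an in-scope `chain` helper are assumed, matching the removed call site.

```rust
fn compute_tree_leaves<PRF>(
    prf_key: &PRF::Key,
    parameter: &Self::Parameter,
    epochs: &[u32],
    num_chains: usize,
    chain_length: usize,
) -> Vec<Self::Domain>
where
    PRF: Pseudorandom,
    PRF::Domain: Into<Self::Domain>,
    Self: Sized,
{
    use rayon::prelude::*;

    epochs
        .par_iter()
        .map(|&epoch| {
            // Walk every chain for this epoch to its public end.
            let chain_ends: Vec<Self::Domain> = (0..num_chains)
                .map(|chain_index| {
                    // Each chain start is a PRF evaluation, as in the removed loop.
                    let start =
                        PRF::get_domain_element(prf_key, epoch, chain_index as u64).into();
                    chain::<Self>(parameter, epoch, chain_index as u8, 0, chain_length - 1, &start)
                })
                .collect();
            // Hash the chain ends into one tree leaf.
            Self::apply(parameter, &Self::tree_tweak(0, epoch), &chain_ends)
        })
        .collect()
}
```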
@@ -77,7 +100,6 @@ pub mod poseidon;

#[cfg(test)]
mod tests {

use crate::symmetric::tweak_hash::poseidon::PoseidonTweak44;

use super::*;