Skip to content

Commit 1a23b5c

Browse files
committed
simd: apply packing for tree leaves
1 parent 2bd2b42 commit 1a23b5c

File tree

7 files changed

+397
-73
lines changed

7 files changed

+397
-73
lines changed

Cargo.toml

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -39,10 +39,10 @@ dashmap = "6.1.0"
3939
serde = { version = "1.0", features = ["derive", "alloc"] }
4040
thiserror = "2.0"
4141

42-
p3-field = { git = "https://github.com/Plonky3/Plonky3.git", rev = "2117e4b" }
43-
p3-baby-bear = { git = "https://github.com/Plonky3/Plonky3.git", rev = "2117e4b" }
44-
p3-koala-bear = { git = "https://github.com/Plonky3/Plonky3.git", rev = "2117e4b" }
45-
p3-symmetric = { git = "https://github.com/Plonky3/Plonky3.git", rev = "2117e4b" }
42+
p3-field = { git = "https://github.com/Plonky3/Plonky3.git", rev = "a33a312" }
43+
p3-baby-bear = { git = "https://github.com/Plonky3/Plonky3.git", rev = "a33a312" }
44+
p3-koala-bear = { git = "https://github.com/Plonky3/Plonky3.git", rev = "a33a312" }
45+
p3-symmetric = { git = "https://github.com/Plonky3/Plonky3.git", rev = "a33a312" }
4646

4747
[dev-dependencies]
4848
criterion = "0.7"

src/lib.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
use p3_field::Field;
12
use p3_koala_bear::{
23
KoalaBear, Poseidon2KoalaBear, default_koalabear_poseidon2_16, default_koalabear_poseidon2_24,
34
};
@@ -11,6 +12,7 @@ pub const TWEAK_SEPARATOR_FOR_TREE_HASH: u8 = 0x01;
1112
pub const TWEAK_SEPARATOR_FOR_CHAIN_HASH: u8 = 0x00;
1213

1314
type F = KoalaBear;
15+
pub(crate) type PackedF = <F as Field>::Packing;
1416

1517
pub(crate) mod hypercube;
1618
pub(crate) mod inc_encoding;

src/signature/generalized_xmss.rs

Lines changed: 7 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -205,38 +205,15 @@ where
205205
let chain_length = IE::BASE;
206206

207207
// the range of epochs covered by that bottom tree
208-
let epoch_range_start = bottom_tree_index * leafs_per_bottom_tree;
209-
let epoch_range_end = epoch_range_start + leafs_per_bottom_tree;
210-
let epoch_range = epoch_range_start..epoch_range_end;
211-
212-
// parallelize the chain ends hash computation for each epoch in the interval for that bottom tree
213-
let chain_ends_hashes = epoch_range
214-
.into_par_iter()
215-
.map(|epoch| {
216-
// each epoch has a number of chains
217-
// parallelize the chain ends computation for each chain
218-
let chain_ends = (0..num_chains)
219-
.into_par_iter()
220-
.map(|chain_index| {
221-
// each chain start is just a PRF evaluation
222-
let start =
223-
PRF::get_domain_element(prf_key, epoch as u32, chain_index as u64).into();
224-
// walk the chain to get the public chain end
225-
chain::<TH>(
226-
parameter,
227-
epoch as u32,
228-
chain_index as u8,
229-
0,
230-
chain_length - 1,
231-
&start,
232-
)
233-
})
234-
.collect::<Vec<_>>();
235-
// build hash of chain ends / public keys
236-
TH::apply(parameter, &TH::tree_tweak(0, epoch as u32), &chain_ends)
237-
})
208+
let epoch_start = bottom_tree_index * leafs_per_bottom_tree;
209+
let epochs: Vec<u32> = (epoch_start..epoch_start + leafs_per_bottom_tree)
210+
.map(|e| e as u32)
238211
.collect();
239212

213+
// Compute chain ends for all epochs.
214+
let chain_ends_hashes =
215+
TH::compute_tree_leaves::<PRF>(prf_key, parameter, &epochs, num_chains, chain_length);
216+
240217
// now that we have the hashes of all chain ends (= leaves of our tree), we can compute the bottom tree
241218
HashSubTree::new_bottom_tree(
242219
LOG_LIFETIME,

src/symmetric/message_hash/poseidon.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -164,7 +164,7 @@ where
164164
.copied()
165165
.collect();
166166

167-
let hash_fe = poseidon_compress::<_, 24, HASH_LEN_FE>(&perm, &combined_input_vec);
167+
let hash_fe = poseidon_compress::<F, _, 24, HASH_LEN_FE>(&perm, &combined_input_vec);
168168

169169
// decode field elements into chunks and return them
170170
decode_to_chunks::<DIMENSION, BASE, HASH_LEN_FE>(&hash_fe).to_vec()

src/symmetric/message_hash/top_level_poseidon.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -159,7 +159,7 @@ where
159159
.collect();
160160

161161
let iteration_pos_output =
162-
poseidon_compress::<_, 24, POS_OUTPUT_LEN_PER_INV_FE>(&perm, &combined_input);
162+
poseidon_compress::<F, _, 24, POS_OUTPUT_LEN_PER_INV_FE>(&perm, &combined_input);
163163

164164
pos_outputs[i * POS_OUTPUT_LEN_PER_INV_FE..(i + 1) * POS_OUTPUT_LEN_PER_INV_FE]
165165
.copy_from_slice(&iteration_pos_output);

src/symmetric/tweak_hash.rs

Lines changed: 52 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
use rand::Rng;
2+
use rayon::prelude::*;
23
use serde::{Serialize, de::DeserializeOwned};
34

5+
use crate::symmetric::prf::Pseudorandom;
6+
47
/// Trait to model a tweakable hash function.
58
/// Such a function takes a public parameter, a tweak, and a
69
/// message to be hashed. The tweak should be understood as an
@@ -14,8 +17,13 @@ use serde::{Serialize, de::DeserializeOwned};
1417
/// to obtain distinct tweaks for applications in chains and
1518
/// applications in Merkle trees.
1619
pub trait TweakableHash {
20+
/// Public parameter type for the hash function
1721
type Parameter: Copy + Sized + Send + Sync + Serialize + DeserializeOwned;
22+
23+
/// Tweak type for domain separation
1824
type Tweak;
25+
26+
/// Domain element type (defines output and input types to the hash)
1927
type Domain: Copy + PartialEq + Sized + Send + Sync + Serialize + DeserializeOwned;
2028

2129
/// Generates a random public parameter.
@@ -39,8 +47,50 @@ pub trait TweakableHash {
3947
message: &[Self::Domain],
4048
) -> Self::Domain;
4149

42-
/// Function to check internal consistency of any given parameters
43-
/// For testing only, and expected to panic if something is wrong.
50+
/// Computes bottom tree leaves by walking hash chains for multiple epochs.
51+
///
52+
/// This method has a default scalar implementation that processes epochs in parallel.
53+
fn compute_tree_leaves<PRF>(
54+
prf_key: &PRF::Key,
55+
parameter: &Self::Parameter,
56+
epochs: &[u32],
57+
num_chains: usize,
58+
chain_length: usize,
59+
) -> Vec<Self::Domain>
60+
where
61+
PRF: Pseudorandom,
62+
PRF::Domain: Into<Self::Domain>,
63+
Self: Sized,
64+
{
65+
// Default scalar implementation: process each epoch in parallel
66+
epochs
67+
.par_iter()
68+
.map(|&epoch| {
69+
// For each epoch, walk all chains in parallel
70+
let chain_ends: Vec<_> = (0..num_chains)
71+
.into_par_iter()
72+
.map(|chain_index| {
73+
let start =
74+
PRF::get_domain_element(prf_key, epoch, chain_index as u64).into();
75+
chain::<Self>(
76+
parameter,
77+
epoch,
78+
chain_index as u8,
79+
0,
80+
chain_length - 1,
81+
&start,
82+
)
83+
})
84+
.collect();
85+
// Hash all chain ends together to get the leaf
86+
Self::apply(parameter, &Self::tree_tweak(0, epoch), &chain_ends)
87+
})
88+
.collect()
89+
}
90+
91+
/// Function to check internal consistency of any given parameters.
92+
///
93+
/// This is for testing only and is expected to panic if something is wrong.
4494
#[cfg(test)]
4595
fn internal_consistency_check();
4696
}
@@ -77,7 +127,6 @@ pub mod poseidon;
77127

78128
#[cfg(test)]
79129
mod tests {
80-
81130
use crate::symmetric::tweak_hash::poseidon::PoseidonTweak44;
82131

83132
use super::*;

0 commit comments

Comments
 (0)