From 6aa85d37e467da374901f3072b3db1d9da0fb56c Mon Sep 17 00:00:00 2001 From: Thomas Coratger Date: Sun, 23 Nov 2025 21:59:48 +0100 Subject: [PATCH 1/6] public key: add ssz impl --- Cargo.toml | 3 + src/signature/generalized_xmss.rs | 232 +++++++++++++++++++++++++++++- 2 files changed, 234 insertions(+), 1 deletion(-) diff --git a/Cargo.toml b/Cargo.toml index 0aefe4a..0986766 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -39,6 +39,9 @@ dashmap = "6.1.0" serde = { version = "1.0", features = ["derive", "alloc"] } thiserror = "2.0" +ssz = { package = "ethereum_ssz", version = "0.10.0" } +ssz_derive = { package = "ethereum_ssz_derive", version = "0.10.0" } + p3-field = { git = "https://github.com/Plonky3/Plonky3.git", rev = "a33a312" } p3-baby-bear = { git = "https://github.com/Plonky3/Plonky3.git", rev = "a33a312" } p3-koala-bear = { git = "https://github.com/Plonky3/Plonky3.git", rev = "a33a312" } diff --git a/src/signature/generalized_xmss.rs b/src/signature/generalized_xmss.rs index 529ed98..a2f7f5f 100644 --- a/src/signature/generalized_xmss.rs +++ b/src/signature/generalized_xmss.rs @@ -5,7 +5,7 @@ use rayon::prelude::*; use serde::{Deserialize, Serialize}; use crate::{ - MESSAGE_LENGTH, + F, MESSAGE_LENGTH, inc_encoding::IncomparableEncoding, signature::SignatureSchemeSecretKey, symmetric::{ @@ -17,6 +17,9 @@ use crate::{ use super::{SignatureScheme, SigningError}; +use p3_field::{PrimeField32, RawDataSerializable}; +use ssz::{Decode, DecodeError, Encode}; + /// Implementation of the generalized XMSS signature scheme /// from any incomparable encoding scheme and any tweakable hash /// @@ -525,6 +528,85 @@ where } } +impl Encode for GeneralizedXMSSPublicKey +where + TH: TweakableHash, +{ + fn is_ssz_fixed_len() -> bool { + true + } + + fn ssz_fixed_len() -> usize { + (HASH_LEN + PARAM_LEN) * F::NUM_BYTES + } + + fn ssz_bytes_len(&self) -> usize { + (HASH_LEN + PARAM_LEN) * F::NUM_BYTES + } + + fn ssz_append(&self, buf: &mut Vec) { + // Reserve space for the output + buf.reserve((HASH_LEN + PARAM_LEN) * F::NUM_BYTES); + + // Encode root + for elem in self.root.iter() { + let value = elem.as_canonical_u32(); + buf.extend_from_slice(&value.to_le_bytes()); + } + // Encode parameter + for elem in self.parameter.iter() { + let value = elem.as_canonical_u32(); + buf.extend_from_slice(&value.to_le_bytes()); + } + } +} + +impl Decode for GeneralizedXMSSPublicKey +where + TH: TweakableHash, +{ + fn is_ssz_fixed_len() -> bool { + true + } + + fn ssz_fixed_len() -> usize { + (HASH_LEN + PARAM_LEN) * F::NUM_BYTES + } + + fn from_ssz_bytes(bytes: &[u8]) -> Result { + let expected_len = (HASH_LEN + PARAM_LEN) * F::NUM_BYTES; + + if bytes.len() != expected_len { + return Err(DecodeError::InvalidByteLength { + len: bytes.len(), + expected: expected_len, + }); + } + + // We know this is safe because of the length check above. + let (root_bytes, param_bytes) = bytes.split_at(HASH_LEN * F::NUM_BYTES); + + // Define a helper closure to decode an array from bytes. + let decode_array = |chunk: &[u8]| -> F { + let val = u32::from_le_bytes(chunk.try_into().unwrap()); + F::new(val) + }; + + // Construct the root and parameter arrays. + let root = std::array::from_fn(|i| { + let start = i * F::NUM_BYTES; + decode_array(&root_bytes[start..start + F::NUM_BYTES]) + }); + + let parameter = std::array::from_fn(|i| { + let start = i * F::NUM_BYTES; + decode_array(¶m_bytes[start..start + F::NUM_BYTES]) + }); + + Ok(Self { root, parameter }) + } +} + /// Instantiations of the generalized XMSS signature scheme based on Poseidon2 pub mod instantiations_poseidon; /// Instantiations of the generalized XMSS signature scheme based on the @@ -545,6 +627,13 @@ mod tests { use super::*; + use crate::symmetric::tweak_hash::poseidon::PoseidonTweakHash; + use p3_field::PrimeCharacteristicRing; + use rand::rng; + use ssz::{Decode, Encode}; + + type TestTH = PoseidonTweakHash<5, 7, 2, 9, 155>; + #[test] pub fn test_target_sum_poseidon() { // Note: do not use these parameters, they are just for testing @@ -679,4 +768,145 @@ mod tests { let (start, end_excl) = expand_activation_time::(12, 2); assert!((start == 2) && (end_excl == 4)); } + + #[test] + fn test_public_key_ssz_roundtrip() { + let mut rng = rng(); + let root = TestTH::rand_domain(&mut rng); + let parameter = TestTH::rand_parameter(&mut rng); + + let public_key = GeneralizedXMSSPublicKey:: { root, parameter }; + + // Encode to SSZ + let encoded = public_key.as_ssz_bytes(); + + // Check expected size: (7 + 5) * 4 = 48 bytes + assert_eq!(encoded.len(), 48); + + // Decode from SSZ + let decoded = + GeneralizedXMSSPublicKey::::from_ssz_bytes(&encoded).expect("Decoding failed"); + + // Check fields match + assert_eq!(public_key.root, decoded.root); + assert_eq!(public_key.parameter, decoded.parameter); + } + + #[test] + fn test_public_key_ssz_deterministic() { + let mut rng = rng(); + let root = TestTH::rand_domain(&mut rng); + let parameter = TestTH::rand_parameter(&mut rng); + + let public_key = GeneralizedXMSSPublicKey:: { root, parameter }; + + // Encode multiple times + let encoded1 = public_key.as_ssz_bytes(); + let encoded2 = public_key.as_ssz_bytes(); + + // Should be identical + assert_eq!(encoded1, encoded2); + } + + #[test] + fn test_public_key_ssz_zero_values() { + let root = [F::ZERO; 7]; + let parameter = [F::ZERO; 5]; + + let public_key = GeneralizedXMSSPublicKey:: { root, parameter }; + + let encoded = public_key.as_ssz_bytes(); + let decoded = + GeneralizedXMSSPublicKey::::from_ssz_bytes(&encoded).expect("Decoding failed"); + + assert_eq!(public_key.root, decoded.root); + assert_eq!(public_key.parameter, decoded.parameter); + } + + #[test] + fn test_public_key_ssz_max_values() { + use p3_field::PrimeField32; + + let max_val = F::ORDER_U32 - 1; + let root = [F::new(max_val); 7]; + let parameter = [F::new(max_val); 5]; + + let public_key = GeneralizedXMSSPublicKey:: { root, parameter }; + + let encoded = public_key.as_ssz_bytes(); + let decoded = + GeneralizedXMSSPublicKey::::from_ssz_bytes(&encoded).expect("Decoding failed"); + + assert_eq!(public_key.root, decoded.root); + assert_eq!(public_key.parameter, decoded.parameter); + } + + #[test] + fn test_public_key_ssz_invalid_length_too_short() { + let bytes = vec![0u8; 47]; // Should be 48 bytes + let result = GeneralizedXMSSPublicKey::::from_ssz_bytes(&bytes); + assert!(result.is_err()); + assert!(matches!( + result, + Err(DecodeError::InvalidByteLength { + len: 47, + expected: 48 + }) + )); + } + + #[test] + fn test_public_key_ssz_invalid_length_too_long() { + let bytes = vec![0u8; 49]; // Should be 48 bytes + let result = GeneralizedXMSSPublicKey::::from_ssz_bytes(&bytes); + assert!(result.is_err()); + assert!(matches!( + result, + Err(DecodeError::InvalidByteLength { + len: 49, + expected: 48 + }) + )); + } + + #[test] + fn test_public_key_ssz_fixed_len_trait() { + assert!( as Encode>::is_ssz_fixed_len()); + assert_eq!( + as Encode>::ssz_fixed_len(), + 48 + ); + } + + #[test] + fn test_public_key_ssz_specific_values() { + // Test with specific known values to verify byte ordering + let root = [ + F::new(1), + F::new(2), + F::new(3), + F::new(4), + F::new(5), + F::new(6), + F::new(7), + ]; + let parameter = [F::new(10), F::new(20), F::new(30), F::new(40), F::new(50)]; + + let public_key = GeneralizedXMSSPublicKey:: { root, parameter }; + + let encoded = public_key.as_ssz_bytes(); + + // Check first few bytes (little-endian encoding of 1) + assert_eq!(&encoded[0..4], &[1, 0, 0, 0]); + // Check encoding of 2 + assert_eq!(&encoded[4..8], &[2, 0, 0, 0]); + // Check encoding of 10 (first parameter value) + assert_eq!(&encoded[28..32], &[10, 0, 0, 0]); + + let decoded = + GeneralizedXMSSPublicKey::::from_ssz_bytes(&encoded).expect("Decoding failed"); + + assert_eq!(public_key.root, decoded.root); + assert_eq!(public_key.parameter, decoded.parameter); + } } From 0d1259a537759d4463eb98c3ab2256a4c7981257 Mon Sep 17 00:00:00 2001 From: Thomas Coratger Date: Sun, 23 Nov 2025 22:02:03 +0100 Subject: [PATCH 2/6] fix clippy --- src/signature/generalized_xmss.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/signature/generalized_xmss.rs b/src/signature/generalized_xmss.rs index a2f7f5f..5ab4f37 100644 --- a/src/signature/generalized_xmss.rs +++ b/src/signature/generalized_xmss.rs @@ -549,12 +549,12 @@ where buf.reserve((HASH_LEN + PARAM_LEN) * F::NUM_BYTES); // Encode root - for elem in self.root.iter() { + for elem in &self.root { let value = elem.as_canonical_u32(); buf.extend_from_slice(&value.to_le_bytes()); } // Encode parameter - for elem in self.parameter.iter() { + for elem in &self.parameter { let value = elem.as_canonical_u32(); buf.extend_from_slice(&value.to_le_bytes()); } From 8257be8ed2bd623f9260ef191095fcdc15a46805 Mon Sep 17 00:00:00 2001 From: Thomas Coratger Date: Mon, 24 Nov 2025 19:05:53 +0100 Subject: [PATCH 3/6] better abstraction --- src/array.rs | 358 +++++++++++++++++++++++++++ src/lib.rs | 1 + src/signature/generalized_xmss.rs | 80 ++---- src/simd_utils.rs | 58 ++--- src/symmetric/tweak_hash.rs | 13 +- src/symmetric/tweak_hash/poseidon.rs | 42 ++-- 6 files changed, 449 insertions(+), 103 deletions(-) create mode 100644 src/array.rs diff --git a/src/array.rs b/src/array.rs new file mode 100644 index 0000000..02ef570 --- /dev/null +++ b/src/array.rs @@ -0,0 +1,358 @@ +use serde::{Deserialize, Deserializer, Serialize, de::Visitor}; +use ssz::{Decode, DecodeError, Encode}; +use std::ops::{Deref, DerefMut}; + +use crate::F; +use p3_field::{PrimeCharacteristicRing, PrimeField32, RawDataSerializable}; + +/// A wrapper around an array of field elements that implements SSZ Encode/Decode. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +#[repr(transparent)] +pub struct FieldArray(pub [F; N]); + +impl Deref for FieldArray { + type Target = [F; N]; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl DerefMut for FieldArray { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.0 + } +} + +impl From<[F; N]> for FieldArray { + fn from(arr: [F; N]) -> Self { + Self(arr) + } +} + +impl From> for [F; N] { + fn from(field_array: FieldArray) -> Self { + field_array.0 + } +} + +impl Encode for FieldArray { + fn is_ssz_fixed_len() -> bool { + true + } + + fn ssz_fixed_len() -> usize { + N * F::NUM_BYTES + } + + fn ssz_bytes_len(&self) -> usize { + N * F::NUM_BYTES + } + + fn ssz_append(&self, buf: &mut Vec) { + buf.reserve(N * F::NUM_BYTES); + for elem in &self.0 { + let value = elem.as_canonical_u32(); + buf.extend_from_slice(&value.to_le_bytes()); + } + } +} + +impl Decode for FieldArray { + fn is_ssz_fixed_len() -> bool { + true + } + + fn ssz_fixed_len() -> usize { + N * F::NUM_BYTES + } + + fn from_ssz_bytes(bytes: &[u8]) -> Result { + let expected_len = N * F::NUM_BYTES; + if bytes.len() != expected_len { + return Err(DecodeError::InvalidByteLength { + len: bytes.len(), + expected: expected_len, + }); + } + + let arr = std::array::from_fn(|i| { + let start = i * F::NUM_BYTES; + let chunk = bytes[start..start + F::NUM_BYTES].try_into().unwrap(); + F::new(u32::from_le_bytes(chunk)) + }); + + Ok(Self(arr)) + } +} + +impl Serialize for FieldArray { + fn serialize(&self, serializer: S) -> Result + where + S: serde::Serializer, + { + serializer.collect_seq(self.0.iter().map(|elem| elem.as_canonical_u32())) + } +} + +impl<'de, const N: usize> Deserialize<'de> for FieldArray { + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + struct FieldArrayVisitor; + + impl<'de, const N: usize> Visitor<'de> for FieldArrayVisitor { + type Value = FieldArray; + + fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { + write!(formatter, "an array of {} field elements", N) + } + + fn visit_seq(self, mut seq: A) -> Result + where + A: serde::de::SeqAccess<'de>, + { + let mut arr = [F::ZERO; N]; + for (i, p) in arr.iter_mut().enumerate() { + let val: u32 = seq + .next_element()? + .ok_or_else(|| serde::de::Error::invalid_length(i, &self))?; + *p = F::new(val); + } + Ok(FieldArray(arr)) + } + } + + deserializer.deserialize_seq(FieldArrayVisitor::) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use proptest::prelude::*; + + /// Small parameter arrays + const SMALL_SIZE: usize = 5; + /// Hash output size + const MEDIUM_SIZE: usize = 7; + /// Larger parameter arrays + const LARGE_SIZE: usize = 44; + + #[test] + fn test_ssz_roundtrip_zero_values() { + // Start with an array of zeros + let original = FieldArray([F::ZERO; SMALL_SIZE]); + + // Encode to bytes using SSZ + let encoded = original.as_ssz_bytes(); + + // Decode back from bytes + let decoded = FieldArray::::from_ssz_bytes(&encoded) + .expect("Failed to decode valid SSZ bytes"); + + // Verify round-trip preserves the value + assert_eq!(original, decoded, "Round-trip failed for zero values"); + } + + #[test] + fn test_ssz_roundtrip_max_values() { + // Create array with maximum valid field values + let max_val = F::ORDER_U32 - 1; + let original = FieldArray([F::new(max_val); MEDIUM_SIZE]); + + // Perform round-trip encoding/decoding + let encoded = original.as_ssz_bytes(); + let decoded = FieldArray::::from_ssz_bytes(&encoded) + .expect("Failed to decode max values"); + + // Verify the values survived the round-trip + assert_eq!(original, decoded, "Round-trip failed for max values"); + } + + #[test] + fn test_ssz_roundtrip_specific_values() { + // Create an array with sequential values for easy verification + let original = FieldArray([F::new(1), F::new(2), F::new(3), F::new(4), F::new(5)]); + + // Encode and verify the byte representation + let encoded = original.as_ssz_bytes(); + + // Each u32 should be encoded as F::NUM_BYTES bytes in little-endian + assert_eq!( + &encoded[0..F::NUM_BYTES], + &[1, 0, 0, 0], + "First element encoding incorrect" + ); + assert_eq!( + &encoded[F::NUM_BYTES..2 * F::NUM_BYTES], + &[2, 0, 0, 0], + "Second element encoding incorrect" + ); + assert_eq!( + &encoded[2 * F::NUM_BYTES..3 * F::NUM_BYTES], + &[3, 0, 0, 0], + "Third element encoding incorrect" + ); + + // Decode and verify round-trip + let decoded = FieldArray::::from_ssz_bytes(&encoded) + .expect("Failed to decode specific values"); + + assert_eq!(original, decoded, "Round-trip failed for specific values"); + } + + #[test] + fn test_ssz_encoding_deterministic() { + let mut rng = rand::rng(); + + // Create a random field array + let field_array = FieldArray(rng.random::<[F; SMALL_SIZE]>()); + + // Encode it multiple times + let encoding1 = field_array.as_ssz_bytes(); + let encoding2 = field_array.as_ssz_bytes(); + let encoding3 = field_array.as_ssz_bytes(); + + // All encodings should be identical + assert_eq!(encoding1, encoding2, "Encoding not deterministic (1 vs 2)"); + assert_eq!(encoding2, encoding3, "Encoding not deterministic (2 vs 3)"); + } + + #[test] + fn test_ssz_encoded_size() { + let field_array = FieldArray([F::ZERO; LARGE_SIZE]); + let encoded = field_array.as_ssz_bytes(); + + // Verify the encoded size matches expectations + let expected_size = LARGE_SIZE * F::NUM_BYTES; + assert_eq!( + encoded.len(), + expected_size, + "Encoded size should be {} bytes (array of {} elements, {} bytes each)", + expected_size, + LARGE_SIZE, + F::NUM_BYTES + ); + + // Also verify the trait method reports the same size + assert_eq!( + field_array.ssz_bytes_len(), + expected_size, + "ssz_bytes_len() should match actual encoded size" + ); + } + + #[test] + fn test_ssz_decode_rejects_wrong_length() { + let expected_len = SMALL_SIZE * F::NUM_BYTES; + + // Test buffer that's too short (missing one byte) + let too_short = vec![0u8; expected_len - 1]; + let result = FieldArray::::from_ssz_bytes(&too_short); + assert!(result.is_err(), "Should reject buffer that's too short"); + if let Err(DecodeError::InvalidByteLength { len, expected }) = result { + assert_eq!(len, expected_len - 1); + assert_eq!(expected, expected_len); + } else { + panic!("Expected InvalidByteLength error"); + } + + // Test buffer that's too long (extra byte) + let too_long = vec![0u8; expected_len + 1]; + let result = FieldArray::::from_ssz_bytes(&too_long); + assert!(result.is_err(), "Should reject buffer that's too long"); + if let Err(DecodeError::InvalidByteLength { len, expected }) = result { + assert_eq!(len, expected_len + 1); + assert_eq!(expected, expected_len); + } else { + panic!("Expected InvalidByteLength error"); + } + } + + #[test] + fn test_ssz_fixed_len_trait_methods() { + // Arrays are always fixed-length in SSZ + assert!( + as Encode>::is_ssz_fixed_len(), + "FieldArray should report as fixed-length (Encode)" + ); + assert!( + as Decode>::is_ssz_fixed_len(), + "FieldArray should report as fixed-length (Decode)" + ); + + // The fixed length should be N * F::NUM_BYTES + let expected_len = SMALL_SIZE * F::NUM_BYTES; + assert_eq!( + as Encode>::ssz_fixed_len(), + expected_len, + "Encode::ssz_fixed_len() incorrect" + ); + assert_eq!( + as Decode>::ssz_fixed_len(), + expected_len, + "Decode::ssz_fixed_len() incorrect" + ); + } + + proptest! { + #[test] + fn proptest_ssz_roundtrip_large( + values in prop::collection::vec(0u32..F::ORDER_U32, LARGE_SIZE) + ) { + // Convert Vec to array for large sizes + let arr: [F; LARGE_SIZE] = std::array::from_fn(|i| F::new(values[i])); + let original = FieldArray(arr); + + let encoded = original.as_ssz_bytes(); + let decoded = FieldArray::::from_ssz_bytes(&encoded) + .expect("Valid SSZ bytes should always decode"); + + prop_assert_eq!(original, decoded); + } + + #[test] + fn proptest_ssz_deterministic( + values in prop::array::uniform5(0u32..F::ORDER_U32) + ) { + let arr = values.map(F::new); + let field_array = FieldArray(arr); + + // Encode twice and verify both encodings are identical + let encoding1 = field_array.as_ssz_bytes(); + let encoding2 = field_array.as_ssz_bytes(); + + prop_assert_eq!(encoding1, encoding2); + } + + #[test] + fn proptest_ssz_size_invariant( + values in prop::array::uniform5(0u32..F::ORDER_U32) + ) { + let arr = values.map(F::new); + let field_array = FieldArray(arr); + + let encoded = field_array.as_ssz_bytes(); + let expected_size = SMALL_SIZE * F::NUM_BYTES; + + prop_assert_eq!(encoded.len(), expected_size); + prop_assert_eq!(field_array.ssz_bytes_len(), expected_size); + } + } + + #[test] + fn test_equality() { + let arr1 = FieldArray([F::new(1), F::new(2), F::new(3)]); + let arr2 = FieldArray([F::new(1), F::new(2), F::new(3)]); + let arr3 = FieldArray([F::new(1), F::new(2), F::new(4)]); + + // Equal arrays should be equal + assert_eq!(arr1, arr2); + + // Different arrays should not be equal + assert_ne!(arr1, arr3); + assert_ne!(arr2, arr3); + } +} diff --git a/src/lib.rs b/src/lib.rs index 53f51b6..af6ffd8 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -14,6 +14,7 @@ pub const TWEAK_SEPARATOR_FOR_CHAIN_HASH: u8 = 0x00; type F = KoalaBear; pub(crate) type PackedF = ::Packing; +pub(crate) mod array; pub(crate) mod hypercube; pub(crate) mod inc_encoding; pub mod signature; diff --git a/src/signature/generalized_xmss.rs b/src/signature/generalized_xmss.rs index 5ab4f37..ff2e36f 100644 --- a/src/signature/generalized_xmss.rs +++ b/src/signature/generalized_xmss.rs @@ -5,7 +5,7 @@ use rayon::prelude::*; use serde::{Deserialize, Serialize}; use crate::{ - F, MESSAGE_LENGTH, + MESSAGE_LENGTH, inc_encoding::IncomparableEncoding, signature::SignatureSchemeSecretKey, symmetric::{ @@ -17,7 +17,6 @@ use crate::{ use super::{SignatureScheme, SigningError}; -use p3_field::{PrimeField32, RawDataSerializable}; use ssz::{Decode, DecodeError, Encode}; /// Implementation of the generalized XMSS signature scheme @@ -528,54 +527,36 @@ where } } -impl Encode for GeneralizedXMSSPublicKey -where - TH: TweakableHash, -{ +impl Encode for GeneralizedXMSSPublicKey { fn is_ssz_fixed_len() -> bool { - true + ::is_ssz_fixed_len() && ::is_ssz_fixed_len() } fn ssz_fixed_len() -> usize { - (HASH_LEN + PARAM_LEN) * F::NUM_BYTES + ::ssz_fixed_len() + ::ssz_fixed_len() } fn ssz_bytes_len(&self) -> usize { - (HASH_LEN + PARAM_LEN) * F::NUM_BYTES + self.root.ssz_bytes_len() + self.parameter.ssz_bytes_len() } fn ssz_append(&self, buf: &mut Vec) { - // Reserve space for the output - buf.reserve((HASH_LEN + PARAM_LEN) * F::NUM_BYTES); - - // Encode root - for elem in &self.root { - let value = elem.as_canonical_u32(); - buf.extend_from_slice(&value.to_le_bytes()); - } - // Encode parameter - for elem in &self.parameter { - let value = elem.as_canonical_u32(); - buf.extend_from_slice(&value.to_le_bytes()); - } + self.root.ssz_append(buf); + self.parameter.ssz_append(buf); } } -impl Decode for GeneralizedXMSSPublicKey -where - TH: TweakableHash, -{ +impl Decode for GeneralizedXMSSPublicKey { fn is_ssz_fixed_len() -> bool { - true + ::is_ssz_fixed_len() && ::is_ssz_fixed_len() } fn ssz_fixed_len() -> usize { - (HASH_LEN + PARAM_LEN) * F::NUM_BYTES + ::ssz_fixed_len() + ::ssz_fixed_len() } fn from_ssz_bytes(bytes: &[u8]) -> Result { - let expected_len = (HASH_LEN + PARAM_LEN) * F::NUM_BYTES; - + let expected_len = ::ssz_fixed_len(); if bytes.len() != expected_len { return Err(DecodeError::InvalidByteLength { len: bytes.len(), @@ -583,25 +564,11 @@ where }); } - // We know this is safe because of the length check above. - let (root_bytes, param_bytes) = bytes.split_at(HASH_LEN * F::NUM_BYTES); - - // Define a helper closure to decode an array from bytes. - let decode_array = |chunk: &[u8]| -> F { - let val = u32::from_le_bytes(chunk.try_into().unwrap()); - F::new(val) - }; - - // Construct the root and parameter arrays. - let root = std::array::from_fn(|i| { - let start = i * F::NUM_BYTES; - decode_array(&root_bytes[start..start + F::NUM_BYTES]) - }); + let root_len = ::ssz_fixed_len(); + let (root_bytes, param_bytes) = bytes.split_at(root_len); - let parameter = std::array::from_fn(|i| { - let start = i * F::NUM_BYTES; - decode_array(¶m_bytes[start..start + F::NUM_BYTES]) - }); + let root = TH::Domain::from_ssz_bytes(root_bytes)?; + let parameter = TH::Parameter::from_ssz_bytes(param_bytes)?; Ok(Self { root, parameter }) } @@ -616,6 +583,7 @@ pub mod instantiations_poseidon_top_level; #[cfg(test)] mod tests { use crate::{ + array::FieldArray, inc_encoding::target_sum::TargetSumEncoding, signature::test_templates::test_signature_scheme_correctness, symmetric::{ @@ -627,7 +595,7 @@ mod tests { use super::*; - use crate::symmetric::tweak_hash::poseidon::PoseidonTweakHash; + use crate::{F, symmetric::tweak_hash::poseidon::PoseidonTweakHash}; use p3_field::PrimeCharacteristicRing; use rand::rng; use ssz::{Decode, Encode}; @@ -810,8 +778,8 @@ mod tests { #[test] fn test_public_key_ssz_zero_values() { - let root = [F::ZERO; 7]; - let parameter = [F::ZERO; 5]; + let root = FieldArray([F::ZERO; 7]); + let parameter = FieldArray([F::ZERO; 5]); let public_key = GeneralizedXMSSPublicKey:: { root, parameter }; @@ -828,8 +796,8 @@ mod tests { use p3_field::PrimeField32; let max_val = F::ORDER_U32 - 1; - let root = [F::new(max_val); 7]; - let parameter = [F::new(max_val); 5]; + let root = FieldArray([F::new(max_val); 7]); + let parameter = FieldArray([F::new(max_val); 5]); let public_key = GeneralizedXMSSPublicKey:: { root, parameter }; @@ -881,7 +849,7 @@ mod tests { #[test] fn test_public_key_ssz_specific_values() { // Test with specific known values to verify byte ordering - let root = [ + let root = FieldArray([ F::new(1), F::new(2), F::new(3), @@ -889,8 +857,8 @@ mod tests { F::new(5), F::new(6), F::new(7), - ]; - let parameter = [F::new(10), F::new(20), F::new(30), F::new(40), F::new(50)]; + ]); + let parameter = FieldArray([F::new(10), F::new(20), F::new(30), F::new(40), F::new(50)]); let public_key = GeneralizedXMSSPublicKey:: { root, parameter }; diff --git a/src/simd_utils.rs b/src/simd_utils.rs index 435422e..c9c74bc 100644 --- a/src/simd_utils.rs +++ b/src/simd_utils.rs @@ -2,17 +2,17 @@ use core::array; use p3_field::PackedValue; -use crate::{F, PackedF}; +use crate::{PackedF, array::FieldArray}; /// Packs scalar arrays into SIMD-friendly vertical layout. /// -/// Transposes from horizontal layout `[[F; N]; WIDTH]` to vertical layout `[PackedF; N]`. +/// Transposes from horizontal layout `[FieldArray; WIDTH]` to vertical layout `[PackedF; N]`. /// -/// Input layout (horizontal): each row is one complete array +/// Input layout (horizontal): each FieldArray is one complete array /// ```text -/// data[0] = [a0, a1, a2, ..., aN] -/// data[1] = [b0, b1, b2, ..., bN] -/// data[2] = [c0, c1, c2, ..., cN] +/// data[0] = FieldArray([a0, a1, a2, ..., aN]) +/// data[1] = FieldArray([b0, b1, b2, ..., bN]) +/// data[2] = FieldArray([c0, c1, c2, ..., cN]) /// ... /// ``` /// @@ -27,16 +27,16 @@ use crate::{F, PackedF}; /// This vertical packing enables efficient SIMD operations where a single instruction /// processes the same element position across multiple arrays simultaneously. #[inline] -pub fn pack_array(data: &[[F; N]]) -> [PackedF; N] { +pub fn pack_array(data: &[FieldArray]) -> [PackedF; N] { array::from_fn(|i| PackedF::from_fn(|j| data[j][i])) } /// Unpacks SIMD vertical layout back into scalar arrays. /// -/// Transposes from vertical layout `[PackedF; N]` to horizontal layout `[[F; N]; WIDTH]`. +/// Transposes from vertical layout `[PackedF; N]` to horizontal layout `[FieldArray; WIDTH]`. /// /// This is the inverse operation of `pack_array`. The output buffer must be preallocated -/// with size `[WIDTH][N]` where `WIDTH = PackedF::WIDTH`. +/// with size `[WIDTH]` where `WIDTH = PackedF::WIDTH`, and each element is a `FieldArray`. /// /// Input layout (vertical): each PackedF holds one element from each array /// ```text @@ -46,15 +46,15 @@ pub fn pack_array(data: &[[F; N]]) -> [PackedF; N] { /// ... /// ``` /// -/// Output layout (horizontal): each row is one complete array +/// Output layout (horizontal): each FieldArray is one complete array /// ```text -/// output[0] = [a0, a1, a2, ..., aN] -/// output[1] = [b0, b1, b2, ..., bN] -/// output[2] = [c0, c1, c2, ..., cN] +/// output[0] = FieldArray([a0, a1, a2, ..., aN]) +/// output[1] = FieldArray([b0, b1, b2, ..., bN]) +/// output[2] = FieldArray([c0, c1, c2, ..., cN]) /// ... /// ``` #[inline] -pub fn unpack_array(packed_data: &[PackedF; N], output: &mut [[F; N]]) { +pub fn unpack_array(packed_data: &[PackedF; N], output: &mut [FieldArray]) { for (i, data) in packed_data.iter().enumerate().take(N) { let unpacked_v = data.as_slice(); for j in 0..PackedF::WIDTH { @@ -65,6 +65,8 @@ pub fn unpack_array(packed_data: &[PackedF; N], output: &mut [[F #[cfg(test)] mod tests { + use crate::F; + use super::*; use p3_field::PrimeCharacteristicRing; use proptest::prelude::*; @@ -73,19 +75,19 @@ mod tests { #[test] fn test_pack_array_simple() { // Test with N=2 (2 field elements per array) - // Create WIDTH arrays of [F; 2] - let data: [[F; 2]; PackedF::WIDTH] = - array::from_fn(|i| [F::from_u64(i as u64), F::from_u64((i + 100) as u64)]); + // Create WIDTH arrays wrapped in FieldArray + let data: [FieldArray<2>; PackedF::WIDTH] = + array::from_fn(|i| FieldArray([F::from_u64(i as u64), F::from_u64((i + 100) as u64)])); let packed = pack_array(&data); // Check that packed[0] contains all first elements - for (lane, &expected) in data.iter().enumerate() { + for (lane, expected) in data.iter().enumerate() { assert_eq!(packed[0].as_slice()[lane], expected[0]); } // Check that packed[1] contains all second elements - for (lane, &expected) in data.iter().enumerate() { + for (lane, expected) in data.iter().enumerate() { assert_eq!(packed[1].as_slice()[lane], expected[1]); } } @@ -99,7 +101,7 @@ mod tests { ]; // Unpack - let mut output = [[F::ZERO; 2]; PackedF::WIDTH]; + let mut output = [FieldArray([F::ZERO; 2]); PackedF::WIDTH]; unpack_array(&packed, &mut output); // Verify @@ -112,12 +114,12 @@ mod tests { #[test] fn test_pack_preserves_element_order() { // Create data where each array has sequential values - let data: [[F; 3]; PackedF::WIDTH] = array::from_fn(|i| { - [ + let data: [FieldArray<3>; PackedF::WIDTH] = array::from_fn(|i| { + FieldArray([ F::from_u64((i * 3) as u64), F::from_u64((i * 3 + 1) as u64), F::from_u64((i * 3 + 2) as u64), - ] + ]) }); let packed = pack_array(&data); @@ -143,7 +145,7 @@ mod tests { PackedF::from_fn(|i| F::from_u64((i * 3 + 2) as u64)), ]; - let mut output = [[F::ZERO; 3]; PackedF::WIDTH]; + let mut output = [FieldArray([F::ZERO; 3]); PackedF::WIDTH]; unpack_array(&packed, &mut output); // Verify each array has sequential values @@ -161,14 +163,14 @@ mod tests { ) { let mut rng = rand::rng(); - // Generate random data with N=10 - let original: [[F; 10]; PackedF::WIDTH] = array::from_fn(|_| { - array::from_fn(|_| rng.random()) + // Generate random data with N=10, using FieldArray + let original: [FieldArray<10>; PackedF::WIDTH] = array::from_fn(|_| { + FieldArray(array::from_fn(|_| rng.random())) }); // Pack and unpack let packed = pack_array(&original); - let mut unpacked = [[F::ZERO; 10]; PackedF::WIDTH]; + let mut unpacked = [FieldArray([F::ZERO; 10]); PackedF::WIDTH]; unpack_array(&packed, &mut unpacked); // Verify roundtrip diff --git a/src/symmetric/tweak_hash.rs b/src/symmetric/tweak_hash.rs index dae6d1c..b12fbe1 100644 --- a/src/symmetric/tweak_hash.rs +++ b/src/symmetric/tweak_hash.rs @@ -1,5 +1,6 @@ use rand::Rng; use serde::{Serialize, de::DeserializeOwned}; +use ssz::{Decode, Encode}; use crate::symmetric::prf::Pseudorandom; @@ -17,13 +18,21 @@ use crate::symmetric::prf::Pseudorandom; /// applications in Merkle trees. pub trait TweakableHash { /// Public parameter type for the hash function - type Parameter: Copy + Sized + Send + Sync + Serialize + DeserializeOwned; + type Parameter: Copy + Sized + Send + Sync + Serialize + DeserializeOwned + Encode + Decode; /// Tweak type for domain separation type Tweak; /// Domain element type (defines output and input types to the hash) - type Domain: Copy + PartialEq + Sized + Send + Sync + Serialize + DeserializeOwned; + type Domain: Copy + + PartialEq + + Sized + + Send + + Sync + + Serialize + + DeserializeOwned + + Encode + + Decode; /// Generates a random public parameter. fn rand_parameter(rng: &mut R) -> Self::Parameter; diff --git a/src/symmetric/tweak_hash/poseidon.rs b/src/symmetric/tweak_hash/poseidon.rs index a584907..c650c8b 100644 --- a/src/symmetric/tweak_hash/poseidon.rs +++ b/src/symmetric/tweak_hash/poseidon.rs @@ -3,10 +3,10 @@ use core::array; use p3_field::{Algebra, PackedValue, PrimeCharacteristicRing, PrimeField64}; use p3_symmetric::CryptographicPermutation; use rayon::prelude::*; -use serde::{Serialize, de::DeserializeOwned}; use crate::TWEAK_SEPARATOR_FOR_CHAIN_HASH; use crate::TWEAK_SEPARATOR_FOR_TREE_HASH; +use crate::array::FieldArray; use crate::poseidon2_16; use crate::poseidon2_24; use crate::simd_utils::{pack_array, unpack_array}; @@ -263,22 +263,19 @@ impl< const CAPACITY: usize, const NUM_CHUNKS: usize, > TweakableHash for PoseidonTweakHash -where - [F; PARAMETER_LEN]: Serialize + DeserializeOwned, - [F; HASH_LEN]: Serialize + DeserializeOwned, { - type Parameter = [F; PARAMETER_LEN]; + type Parameter = FieldArray; type Tweak = PoseidonTweak; - type Domain = [F; HASH_LEN]; + type Domain = FieldArray; fn rand_parameter(rng: &mut R) -> Self::Parameter { - rng.random() + FieldArray(rng.random()) } fn rand_domain(rng: &mut R) -> Self::Domain { - rng.random() + FieldArray(rng.random()) } fn tree_tweak(level: u8, pos_in_level: u32) -> Self::Tweak { @@ -318,7 +315,12 @@ where .chain(single.iter()) .copied() .collect(); - poseidon_compress::(&perm, &combined_input) + FieldArray( + poseidon_compress::( + &perm, + &combined_input, + ), + ) } [left, right] => { @@ -331,7 +333,12 @@ where .chain(right.iter()) .copied() .collect(); - poseidon_compress::(&perm, &combined_input) + FieldArray( + poseidon_compress::( + &perm, + &combined_input, + ), + ) } _ if message.len() > 2 => { @@ -340,7 +347,7 @@ where let combined_input: Vec = parameter .iter() .chain(tweak_fe.iter()) - .chain(message.iter().flatten()) + .chain(message.iter().flat_map(|x| x.iter())) .copied() .collect(); @@ -354,13 +361,13 @@ where poseidon_safe_domain_separator::( &perm, &lengths, ); - poseidon_sponge::( + FieldArray(poseidon_sponge::( &perm, &capacity_value, &combined_input, - ) + )) } - _ => [F::ONE; HASH_LEN], // Unreachable case, added for safety + _ => FieldArray([F::ONE; HASH_LEN]), // Unreachable case, added for safety } } @@ -394,7 +401,7 @@ where let width = PackedF::WIDTH; // Allocate output buffer for all leaves. - let mut leaves = vec![[F::ZERO; HASH_LEN]; epochs.len()]; + let mut leaves = vec![FieldArray([F::ZERO; HASH_LEN]); epochs.len()]; // PREPARE PACKED CONSTANTS @@ -442,7 +449,7 @@ where let mut packed_chains: [[PackedF; HASH_LEN]; NUM_CHUNKS] = array::from_fn(|c_idx| { // Generate starting points for this chain across all epochs. - let starts: [[F; HASH_LEN]; PackedF::WIDTH] = array::from_fn(|lane| { + let starts: [_; PackedF::WIDTH] = array::from_fn(|lane| { PRF::get_domain_element(prf_key, epoch_chunk[lane], c_idx as u64).into() }); @@ -543,7 +550,8 @@ where // // Convert from vertical packing back to scalar layout. // Each lane becomes one leaf in the output slice. - + // + // No unsafe transmute needed - unpack_array accepts &mut [FieldArray] directly. unpack_array(&packed_leaves, leaves_chunk); }); From 19c926a3e01fb8a199416072fd1e12994d3b0daa Mon Sep 17 00:00:00 2001 From: Thomas Coratger Date: Mon, 24 Nov 2025 19:10:31 +0100 Subject: [PATCH 4/6] trait bound for public key --- src/signature.rs | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/signature.rs b/src/signature.rs index 63d9802..0d04672 100644 --- a/src/signature.rs +++ b/src/signature.rs @@ -3,6 +3,7 @@ use std::ops::Range; use crate::MESSAGE_LENGTH; use rand::Rng; use serde::{Serialize, de::DeserializeOwned}; +use ssz::{Decode, Encode}; use thiserror::Error; /// Error enum for the signing process. @@ -96,7 +97,9 @@ pub trait SignatureScheme { /// The public key used for verification. /// /// The key must be serializable to allow for network transmission and storage. - type PublicKey: Serialize + DeserializeOwned; + /// + /// We must support SSZ encoding for Ethereum consensus layer compatibility. + type PublicKey: Serialize + DeserializeOwned + Encode + Decode; /// The secret key used for signing. /// From 3c2978c5a7923c05cb4d00f4798c86b3c36bd7d8 Mon Sep 17 00:00:00 2001 From: Thomas Coratger Date: Mon, 24 Nov 2025 19:10:57 +0100 Subject: [PATCH 5/6] clippy --- src/array.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/array.rs b/src/array.rs index 02ef570..8cf8b38 100644 --- a/src/array.rs +++ b/src/array.rs @@ -91,7 +91,7 @@ impl Serialize for FieldArray { where S: serde::Serializer, { - serializer.collect_seq(self.0.iter().map(|elem| elem.as_canonical_u32())) + serializer.collect_seq(self.0.iter().map(PrimeField32::as_canonical_u32)) } } From 1cb2779cb4a9d77b9242f37e131e7f5d9d1bf9d2 Mon Sep 17 00:00:00 2001 From: Thomas Coratger Date: Mon, 24 Nov 2025 19:16:07 +0100 Subject: [PATCH 6/6] rm useless comment --- src/symmetric/tweak_hash/poseidon.rs | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/symmetric/tweak_hash/poseidon.rs b/src/symmetric/tweak_hash/poseidon.rs index c650c8b..3ae8db4 100644 --- a/src/symmetric/tweak_hash/poseidon.rs +++ b/src/symmetric/tweak_hash/poseidon.rs @@ -550,8 +550,6 @@ impl< // // Convert from vertical packing back to scalar layout. // Each lane becomes one leaf in the output slice. - // - // No unsafe transmute needed - unpack_array accepts &mut [FieldArray] directly. unpack_array(&packed_leaves, leaves_chunk); });