diff --git a/Cargo.toml b/Cargo.toml index 0aefe4a..0986766 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -39,6 +39,9 @@ dashmap = "6.1.0" serde = { version = "1.0", features = ["derive", "alloc"] } thiserror = "2.0" +ssz = { package = "ethereum_ssz", version = "0.10.0" } +ssz_derive = { package = "ethereum_ssz_derive", version = "0.10.0" } + p3-field = { git = "https://github.com/Plonky3/Plonky3.git", rev = "a33a312" } p3-baby-bear = { git = "https://github.com/Plonky3/Plonky3.git", rev = "a33a312" } p3-koala-bear = { git = "https://github.com/Plonky3/Plonky3.git", rev = "a33a312" } diff --git a/src/array.rs b/src/array.rs new file mode 100644 index 0000000..8cf8b38 --- /dev/null +++ b/src/array.rs @@ -0,0 +1,358 @@ +use serde::{Deserialize, Deserializer, Serialize, de::Visitor}; +use ssz::{Decode, DecodeError, Encode}; +use std::ops::{Deref, DerefMut}; + +use crate::F; +use p3_field::{PrimeCharacteristicRing, PrimeField32, RawDataSerializable}; + +/// A wrapper around an array of field elements that implements SSZ Encode/Decode. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +#[repr(transparent)] +pub struct FieldArray(pub [F; N]); + +impl Deref for FieldArray { + type Target = [F; N]; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl DerefMut for FieldArray { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.0 + } +} + +impl From<[F; N]> for FieldArray { + fn from(arr: [F; N]) -> Self { + Self(arr) + } +} + +impl From> for [F; N] { + fn from(field_array: FieldArray) -> Self { + field_array.0 + } +} + +impl Encode for FieldArray { + fn is_ssz_fixed_len() -> bool { + true + } + + fn ssz_fixed_len() -> usize { + N * F::NUM_BYTES + } + + fn ssz_bytes_len(&self) -> usize { + N * F::NUM_BYTES + } + + fn ssz_append(&self, buf: &mut Vec) { + buf.reserve(N * F::NUM_BYTES); + for elem in &self.0 { + let value = elem.as_canonical_u32(); + buf.extend_from_slice(&value.to_le_bytes()); + } + } +} + +impl Decode for FieldArray { + fn is_ssz_fixed_len() -> bool { + true + } + + fn ssz_fixed_len() -> usize { + N * F::NUM_BYTES + } + + fn from_ssz_bytes(bytes: &[u8]) -> Result { + let expected_len = N * F::NUM_BYTES; + if bytes.len() != expected_len { + return Err(DecodeError::InvalidByteLength { + len: bytes.len(), + expected: expected_len, + }); + } + + let arr = std::array::from_fn(|i| { + let start = i * F::NUM_BYTES; + let chunk = bytes[start..start + F::NUM_BYTES].try_into().unwrap(); + F::new(u32::from_le_bytes(chunk)) + }); + + Ok(Self(arr)) + } +} + +impl Serialize for FieldArray { + fn serialize(&self, serializer: S) -> Result + where + S: serde::Serializer, + { + serializer.collect_seq(self.0.iter().map(PrimeField32::as_canonical_u32)) + } +} + +impl<'de, const N: usize> Deserialize<'de> for FieldArray { + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + struct FieldArrayVisitor; + + impl<'de, const N: usize> Visitor<'de> for FieldArrayVisitor { + type Value = FieldArray; + + fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { + write!(formatter, "an array of {} field elements", N) + } + + fn visit_seq(self, mut seq: A) -> Result + where + A: serde::de::SeqAccess<'de>, + { + let mut arr = [F::ZERO; N]; + for (i, p) in arr.iter_mut().enumerate() { + let val: u32 = seq + .next_element()? + .ok_or_else(|| serde::de::Error::invalid_length(i, &self))?; + *p = F::new(val); + } + Ok(FieldArray(arr)) + } + } + + deserializer.deserialize_seq(FieldArrayVisitor::) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use proptest::prelude::*; + + /// Small parameter arrays + const SMALL_SIZE: usize = 5; + /// Hash output size + const MEDIUM_SIZE: usize = 7; + /// Larger parameter arrays + const LARGE_SIZE: usize = 44; + + #[test] + fn test_ssz_roundtrip_zero_values() { + // Start with an array of zeros + let original = FieldArray([F::ZERO; SMALL_SIZE]); + + // Encode to bytes using SSZ + let encoded = original.as_ssz_bytes(); + + // Decode back from bytes + let decoded = FieldArray::::from_ssz_bytes(&encoded) + .expect("Failed to decode valid SSZ bytes"); + + // Verify round-trip preserves the value + assert_eq!(original, decoded, "Round-trip failed for zero values"); + } + + #[test] + fn test_ssz_roundtrip_max_values() { + // Create array with maximum valid field values + let max_val = F::ORDER_U32 - 1; + let original = FieldArray([F::new(max_val); MEDIUM_SIZE]); + + // Perform round-trip encoding/decoding + let encoded = original.as_ssz_bytes(); + let decoded = FieldArray::::from_ssz_bytes(&encoded) + .expect("Failed to decode max values"); + + // Verify the values survived the round-trip + assert_eq!(original, decoded, "Round-trip failed for max values"); + } + + #[test] + fn test_ssz_roundtrip_specific_values() { + // Create an array with sequential values for easy verification + let original = FieldArray([F::new(1), F::new(2), F::new(3), F::new(4), F::new(5)]); + + // Encode and verify the byte representation + let encoded = original.as_ssz_bytes(); + + // Each u32 should be encoded as F::NUM_BYTES bytes in little-endian + assert_eq!( + &encoded[0..F::NUM_BYTES], + &[1, 0, 0, 0], + "First element encoding incorrect" + ); + assert_eq!( + &encoded[F::NUM_BYTES..2 * F::NUM_BYTES], + &[2, 0, 0, 0], + "Second element encoding incorrect" + ); + assert_eq!( + &encoded[2 * F::NUM_BYTES..3 * F::NUM_BYTES], + &[3, 0, 0, 0], + "Third element encoding incorrect" + ); + + // Decode and verify round-trip + let decoded = FieldArray::::from_ssz_bytes(&encoded) + .expect("Failed to decode specific values"); + + assert_eq!(original, decoded, "Round-trip failed for specific values"); + } + + #[test] + fn test_ssz_encoding_deterministic() { + let mut rng = rand::rng(); + + // Create a random field array + let field_array = FieldArray(rng.random::<[F; SMALL_SIZE]>()); + + // Encode it multiple times + let encoding1 = field_array.as_ssz_bytes(); + let encoding2 = field_array.as_ssz_bytes(); + let encoding3 = field_array.as_ssz_bytes(); + + // All encodings should be identical + assert_eq!(encoding1, encoding2, "Encoding not deterministic (1 vs 2)"); + assert_eq!(encoding2, encoding3, "Encoding not deterministic (2 vs 3)"); + } + + #[test] + fn test_ssz_encoded_size() { + let field_array = FieldArray([F::ZERO; LARGE_SIZE]); + let encoded = field_array.as_ssz_bytes(); + + // Verify the encoded size matches expectations + let expected_size = LARGE_SIZE * F::NUM_BYTES; + assert_eq!( + encoded.len(), + expected_size, + "Encoded size should be {} bytes (array of {} elements, {} bytes each)", + expected_size, + LARGE_SIZE, + F::NUM_BYTES + ); + + // Also verify the trait method reports the same size + assert_eq!( + field_array.ssz_bytes_len(), + expected_size, + "ssz_bytes_len() should match actual encoded size" + ); + } + + #[test] + fn test_ssz_decode_rejects_wrong_length() { + let expected_len = SMALL_SIZE * F::NUM_BYTES; + + // Test buffer that's too short (missing one byte) + let too_short = vec![0u8; expected_len - 1]; + let result = FieldArray::::from_ssz_bytes(&too_short); + assert!(result.is_err(), "Should reject buffer that's too short"); + if let Err(DecodeError::InvalidByteLength { len, expected }) = result { + assert_eq!(len, expected_len - 1); + assert_eq!(expected, expected_len); + } else { + panic!("Expected InvalidByteLength error"); + } + + // Test buffer that's too long (extra byte) + let too_long = vec![0u8; expected_len + 1]; + let result = FieldArray::::from_ssz_bytes(&too_long); + assert!(result.is_err(), "Should reject buffer that's too long"); + if let Err(DecodeError::InvalidByteLength { len, expected }) = result { + assert_eq!(len, expected_len + 1); + assert_eq!(expected, expected_len); + } else { + panic!("Expected InvalidByteLength error"); + } + } + + #[test] + fn test_ssz_fixed_len_trait_methods() { + // Arrays are always fixed-length in SSZ + assert!( + as Encode>::is_ssz_fixed_len(), + "FieldArray should report as fixed-length (Encode)" + ); + assert!( + as Decode>::is_ssz_fixed_len(), + "FieldArray should report as fixed-length (Decode)" + ); + + // The fixed length should be N * F::NUM_BYTES + let expected_len = SMALL_SIZE * F::NUM_BYTES; + assert_eq!( + as Encode>::ssz_fixed_len(), + expected_len, + "Encode::ssz_fixed_len() incorrect" + ); + assert_eq!( + as Decode>::ssz_fixed_len(), + expected_len, + "Decode::ssz_fixed_len() incorrect" + ); + } + + proptest! { + #[test] + fn proptest_ssz_roundtrip_large( + values in prop::collection::vec(0u32..F::ORDER_U32, LARGE_SIZE) + ) { + // Convert Vec to array for large sizes + let arr: [F; LARGE_SIZE] = std::array::from_fn(|i| F::new(values[i])); + let original = FieldArray(arr); + + let encoded = original.as_ssz_bytes(); + let decoded = FieldArray::::from_ssz_bytes(&encoded) + .expect("Valid SSZ bytes should always decode"); + + prop_assert_eq!(original, decoded); + } + + #[test] + fn proptest_ssz_deterministic( + values in prop::array::uniform5(0u32..F::ORDER_U32) + ) { + let arr = values.map(F::new); + let field_array = FieldArray(arr); + + // Encode twice and verify both encodings are identical + let encoding1 = field_array.as_ssz_bytes(); + let encoding2 = field_array.as_ssz_bytes(); + + prop_assert_eq!(encoding1, encoding2); + } + + #[test] + fn proptest_ssz_size_invariant( + values in prop::array::uniform5(0u32..F::ORDER_U32) + ) { + let arr = values.map(F::new); + let field_array = FieldArray(arr); + + let encoded = field_array.as_ssz_bytes(); + let expected_size = SMALL_SIZE * F::NUM_BYTES; + + prop_assert_eq!(encoded.len(), expected_size); + prop_assert_eq!(field_array.ssz_bytes_len(), expected_size); + } + } + + #[test] + fn test_equality() { + let arr1 = FieldArray([F::new(1), F::new(2), F::new(3)]); + let arr2 = FieldArray([F::new(1), F::new(2), F::new(3)]); + let arr3 = FieldArray([F::new(1), F::new(2), F::new(4)]); + + // Equal arrays should be equal + assert_eq!(arr1, arr2); + + // Different arrays should not be equal + assert_ne!(arr1, arr3); + assert_ne!(arr2, arr3); + } +} diff --git a/src/lib.rs b/src/lib.rs index 53f51b6..af6ffd8 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -14,6 +14,7 @@ pub const TWEAK_SEPARATOR_FOR_CHAIN_HASH: u8 = 0x00; type F = KoalaBear; pub(crate) type PackedF = ::Packing; +pub(crate) mod array; pub(crate) mod hypercube; pub(crate) mod inc_encoding; pub mod signature; diff --git a/src/signature.rs b/src/signature.rs index 63d9802..0d04672 100644 --- a/src/signature.rs +++ b/src/signature.rs @@ -3,6 +3,7 @@ use std::ops::Range; use crate::MESSAGE_LENGTH; use rand::Rng; use serde::{Serialize, de::DeserializeOwned}; +use ssz::{Decode, Encode}; use thiserror::Error; /// Error enum for the signing process. @@ -96,7 +97,9 @@ pub trait SignatureScheme { /// The public key used for verification. /// /// The key must be serializable to allow for network transmission and storage. - type PublicKey: Serialize + DeserializeOwned; + /// + /// We must support SSZ encoding for Ethereum consensus layer compatibility. + type PublicKey: Serialize + DeserializeOwned + Encode + Decode; /// The secret key used for signing. /// diff --git a/src/signature/generalized_xmss.rs b/src/signature/generalized_xmss.rs index 529ed98..ff2e36f 100644 --- a/src/signature/generalized_xmss.rs +++ b/src/signature/generalized_xmss.rs @@ -17,6 +17,8 @@ use crate::{ use super::{SignatureScheme, SigningError}; +use ssz::{Decode, DecodeError, Encode}; + /// Implementation of the generalized XMSS signature scheme /// from any incomparable encoding scheme and any tweakable hash /// @@ -525,6 +527,53 @@ where } } +impl Encode for GeneralizedXMSSPublicKey { + fn is_ssz_fixed_len() -> bool { + ::is_ssz_fixed_len() && ::is_ssz_fixed_len() + } + + fn ssz_fixed_len() -> usize { + ::ssz_fixed_len() + ::ssz_fixed_len() + } + + fn ssz_bytes_len(&self) -> usize { + self.root.ssz_bytes_len() + self.parameter.ssz_bytes_len() + } + + fn ssz_append(&self, buf: &mut Vec) { + self.root.ssz_append(buf); + self.parameter.ssz_append(buf); + } +} + +impl Decode for GeneralizedXMSSPublicKey { + fn is_ssz_fixed_len() -> bool { + ::is_ssz_fixed_len() && ::is_ssz_fixed_len() + } + + fn ssz_fixed_len() -> usize { + ::ssz_fixed_len() + ::ssz_fixed_len() + } + + fn from_ssz_bytes(bytes: &[u8]) -> Result { + let expected_len = ::ssz_fixed_len(); + if bytes.len() != expected_len { + return Err(DecodeError::InvalidByteLength { + len: bytes.len(), + expected: expected_len, + }); + } + + let root_len = ::ssz_fixed_len(); + let (root_bytes, param_bytes) = bytes.split_at(root_len); + + let root = TH::Domain::from_ssz_bytes(root_bytes)?; + let parameter = TH::Parameter::from_ssz_bytes(param_bytes)?; + + Ok(Self { root, parameter }) + } +} + /// Instantiations of the generalized XMSS signature scheme based on Poseidon2 pub mod instantiations_poseidon; /// Instantiations of the generalized XMSS signature scheme based on the @@ -534,6 +583,7 @@ pub mod instantiations_poseidon_top_level; #[cfg(test)] mod tests { use crate::{ + array::FieldArray, inc_encoding::target_sum::TargetSumEncoding, signature::test_templates::test_signature_scheme_correctness, symmetric::{ @@ -545,6 +595,13 @@ mod tests { use super::*; + use crate::{F, symmetric::tweak_hash::poseidon::PoseidonTweakHash}; + use p3_field::PrimeCharacteristicRing; + use rand::rng; + use ssz::{Decode, Encode}; + + type TestTH = PoseidonTweakHash<5, 7, 2, 9, 155>; + #[test] pub fn test_target_sum_poseidon() { // Note: do not use these parameters, they are just for testing @@ -679,4 +736,145 @@ mod tests { let (start, end_excl) = expand_activation_time::(12, 2); assert!((start == 2) && (end_excl == 4)); } + + #[test] + fn test_public_key_ssz_roundtrip() { + let mut rng = rng(); + let root = TestTH::rand_domain(&mut rng); + let parameter = TestTH::rand_parameter(&mut rng); + + let public_key = GeneralizedXMSSPublicKey:: { root, parameter }; + + // Encode to SSZ + let encoded = public_key.as_ssz_bytes(); + + // Check expected size: (7 + 5) * 4 = 48 bytes + assert_eq!(encoded.len(), 48); + + // Decode from SSZ + let decoded = + GeneralizedXMSSPublicKey::::from_ssz_bytes(&encoded).expect("Decoding failed"); + + // Check fields match + assert_eq!(public_key.root, decoded.root); + assert_eq!(public_key.parameter, decoded.parameter); + } + + #[test] + fn test_public_key_ssz_deterministic() { + let mut rng = rng(); + let root = TestTH::rand_domain(&mut rng); + let parameter = TestTH::rand_parameter(&mut rng); + + let public_key = GeneralizedXMSSPublicKey:: { root, parameter }; + + // Encode multiple times + let encoded1 = public_key.as_ssz_bytes(); + let encoded2 = public_key.as_ssz_bytes(); + + // Should be identical + assert_eq!(encoded1, encoded2); + } + + #[test] + fn test_public_key_ssz_zero_values() { + let root = FieldArray([F::ZERO; 7]); + let parameter = FieldArray([F::ZERO; 5]); + + let public_key = GeneralizedXMSSPublicKey:: { root, parameter }; + + let encoded = public_key.as_ssz_bytes(); + let decoded = + GeneralizedXMSSPublicKey::::from_ssz_bytes(&encoded).expect("Decoding failed"); + + assert_eq!(public_key.root, decoded.root); + assert_eq!(public_key.parameter, decoded.parameter); + } + + #[test] + fn test_public_key_ssz_max_values() { + use p3_field::PrimeField32; + + let max_val = F::ORDER_U32 - 1; + let root = FieldArray([F::new(max_val); 7]); + let parameter = FieldArray([F::new(max_val); 5]); + + let public_key = GeneralizedXMSSPublicKey:: { root, parameter }; + + let encoded = public_key.as_ssz_bytes(); + let decoded = + GeneralizedXMSSPublicKey::::from_ssz_bytes(&encoded).expect("Decoding failed"); + + assert_eq!(public_key.root, decoded.root); + assert_eq!(public_key.parameter, decoded.parameter); + } + + #[test] + fn test_public_key_ssz_invalid_length_too_short() { + let bytes = vec![0u8; 47]; // Should be 48 bytes + let result = GeneralizedXMSSPublicKey::::from_ssz_bytes(&bytes); + assert!(result.is_err()); + assert!(matches!( + result, + Err(DecodeError::InvalidByteLength { + len: 47, + expected: 48 + }) + )); + } + + #[test] + fn test_public_key_ssz_invalid_length_too_long() { + let bytes = vec![0u8; 49]; // Should be 48 bytes + let result = GeneralizedXMSSPublicKey::::from_ssz_bytes(&bytes); + assert!(result.is_err()); + assert!(matches!( + result, + Err(DecodeError::InvalidByteLength { + len: 49, + expected: 48 + }) + )); + } + + #[test] + fn test_public_key_ssz_fixed_len_trait() { + assert!( as Encode>::is_ssz_fixed_len()); + assert_eq!( + as Encode>::ssz_fixed_len(), + 48 + ); + } + + #[test] + fn test_public_key_ssz_specific_values() { + // Test with specific known values to verify byte ordering + let root = FieldArray([ + F::new(1), + F::new(2), + F::new(3), + F::new(4), + F::new(5), + F::new(6), + F::new(7), + ]); + let parameter = FieldArray([F::new(10), F::new(20), F::new(30), F::new(40), F::new(50)]); + + let public_key = GeneralizedXMSSPublicKey:: { root, parameter }; + + let encoded = public_key.as_ssz_bytes(); + + // Check first few bytes (little-endian encoding of 1) + assert_eq!(&encoded[0..4], &[1, 0, 0, 0]); + // Check encoding of 2 + assert_eq!(&encoded[4..8], &[2, 0, 0, 0]); + // Check encoding of 10 (first parameter value) + assert_eq!(&encoded[28..32], &[10, 0, 0, 0]); + + let decoded = + GeneralizedXMSSPublicKey::::from_ssz_bytes(&encoded).expect("Decoding failed"); + + assert_eq!(public_key.root, decoded.root); + assert_eq!(public_key.parameter, decoded.parameter); + } } diff --git a/src/simd_utils.rs b/src/simd_utils.rs index 435422e..c9c74bc 100644 --- a/src/simd_utils.rs +++ b/src/simd_utils.rs @@ -2,17 +2,17 @@ use core::array; use p3_field::PackedValue; -use crate::{F, PackedF}; +use crate::{PackedF, array::FieldArray}; /// Packs scalar arrays into SIMD-friendly vertical layout. /// -/// Transposes from horizontal layout `[[F; N]; WIDTH]` to vertical layout `[PackedF; N]`. +/// Transposes from horizontal layout `[FieldArray; WIDTH]` to vertical layout `[PackedF; N]`. /// -/// Input layout (horizontal): each row is one complete array +/// Input layout (horizontal): each FieldArray is one complete array /// ```text -/// data[0] = [a0, a1, a2, ..., aN] -/// data[1] = [b0, b1, b2, ..., bN] -/// data[2] = [c0, c1, c2, ..., cN] +/// data[0] = FieldArray([a0, a1, a2, ..., aN]) +/// data[1] = FieldArray([b0, b1, b2, ..., bN]) +/// data[2] = FieldArray([c0, c1, c2, ..., cN]) /// ... /// ``` /// @@ -27,16 +27,16 @@ use crate::{F, PackedF}; /// This vertical packing enables efficient SIMD operations where a single instruction /// processes the same element position across multiple arrays simultaneously. #[inline] -pub fn pack_array(data: &[[F; N]]) -> [PackedF; N] { +pub fn pack_array(data: &[FieldArray]) -> [PackedF; N] { array::from_fn(|i| PackedF::from_fn(|j| data[j][i])) } /// Unpacks SIMD vertical layout back into scalar arrays. /// -/// Transposes from vertical layout `[PackedF; N]` to horizontal layout `[[F; N]; WIDTH]`. +/// Transposes from vertical layout `[PackedF; N]` to horizontal layout `[FieldArray; WIDTH]`. /// /// This is the inverse operation of `pack_array`. The output buffer must be preallocated -/// with size `[WIDTH][N]` where `WIDTH = PackedF::WIDTH`. +/// with size `[WIDTH]` where `WIDTH = PackedF::WIDTH`, and each element is a `FieldArray`. /// /// Input layout (vertical): each PackedF holds one element from each array /// ```text @@ -46,15 +46,15 @@ pub fn pack_array(data: &[[F; N]]) -> [PackedF; N] { /// ... /// ``` /// -/// Output layout (horizontal): each row is one complete array +/// Output layout (horizontal): each FieldArray is one complete array /// ```text -/// output[0] = [a0, a1, a2, ..., aN] -/// output[1] = [b0, b1, b2, ..., bN] -/// output[2] = [c0, c1, c2, ..., cN] +/// output[0] = FieldArray([a0, a1, a2, ..., aN]) +/// output[1] = FieldArray([b0, b1, b2, ..., bN]) +/// output[2] = FieldArray([c0, c1, c2, ..., cN]) /// ... /// ``` #[inline] -pub fn unpack_array(packed_data: &[PackedF; N], output: &mut [[F; N]]) { +pub fn unpack_array(packed_data: &[PackedF; N], output: &mut [FieldArray]) { for (i, data) in packed_data.iter().enumerate().take(N) { let unpacked_v = data.as_slice(); for j in 0..PackedF::WIDTH { @@ -65,6 +65,8 @@ pub fn unpack_array(packed_data: &[PackedF; N], output: &mut [[F #[cfg(test)] mod tests { + use crate::F; + use super::*; use p3_field::PrimeCharacteristicRing; use proptest::prelude::*; @@ -73,19 +75,19 @@ mod tests { #[test] fn test_pack_array_simple() { // Test with N=2 (2 field elements per array) - // Create WIDTH arrays of [F; 2] - let data: [[F; 2]; PackedF::WIDTH] = - array::from_fn(|i| [F::from_u64(i as u64), F::from_u64((i + 100) as u64)]); + // Create WIDTH arrays wrapped in FieldArray + let data: [FieldArray<2>; PackedF::WIDTH] = + array::from_fn(|i| FieldArray([F::from_u64(i as u64), F::from_u64((i + 100) as u64)])); let packed = pack_array(&data); // Check that packed[0] contains all first elements - for (lane, &expected) in data.iter().enumerate() { + for (lane, expected) in data.iter().enumerate() { assert_eq!(packed[0].as_slice()[lane], expected[0]); } // Check that packed[1] contains all second elements - for (lane, &expected) in data.iter().enumerate() { + for (lane, expected) in data.iter().enumerate() { assert_eq!(packed[1].as_slice()[lane], expected[1]); } } @@ -99,7 +101,7 @@ mod tests { ]; // Unpack - let mut output = [[F::ZERO; 2]; PackedF::WIDTH]; + let mut output = [FieldArray([F::ZERO; 2]); PackedF::WIDTH]; unpack_array(&packed, &mut output); // Verify @@ -112,12 +114,12 @@ mod tests { #[test] fn test_pack_preserves_element_order() { // Create data where each array has sequential values - let data: [[F; 3]; PackedF::WIDTH] = array::from_fn(|i| { - [ + let data: [FieldArray<3>; PackedF::WIDTH] = array::from_fn(|i| { + FieldArray([ F::from_u64((i * 3) as u64), F::from_u64((i * 3 + 1) as u64), F::from_u64((i * 3 + 2) as u64), - ] + ]) }); let packed = pack_array(&data); @@ -143,7 +145,7 @@ mod tests { PackedF::from_fn(|i| F::from_u64((i * 3 + 2) as u64)), ]; - let mut output = [[F::ZERO; 3]; PackedF::WIDTH]; + let mut output = [FieldArray([F::ZERO; 3]); PackedF::WIDTH]; unpack_array(&packed, &mut output); // Verify each array has sequential values @@ -161,14 +163,14 @@ mod tests { ) { let mut rng = rand::rng(); - // Generate random data with N=10 - let original: [[F; 10]; PackedF::WIDTH] = array::from_fn(|_| { - array::from_fn(|_| rng.random()) + // Generate random data with N=10, using FieldArray + let original: [FieldArray<10>; PackedF::WIDTH] = array::from_fn(|_| { + FieldArray(array::from_fn(|_| rng.random())) }); // Pack and unpack let packed = pack_array(&original); - let mut unpacked = [[F::ZERO; 10]; PackedF::WIDTH]; + let mut unpacked = [FieldArray([F::ZERO; 10]); PackedF::WIDTH]; unpack_array(&packed, &mut unpacked); // Verify roundtrip diff --git a/src/symmetric/tweak_hash.rs b/src/symmetric/tweak_hash.rs index dae6d1c..b12fbe1 100644 --- a/src/symmetric/tweak_hash.rs +++ b/src/symmetric/tweak_hash.rs @@ -1,5 +1,6 @@ use rand::Rng; use serde::{Serialize, de::DeserializeOwned}; +use ssz::{Decode, Encode}; use crate::symmetric::prf::Pseudorandom; @@ -17,13 +18,21 @@ use crate::symmetric::prf::Pseudorandom; /// applications in Merkle trees. pub trait TweakableHash { /// Public parameter type for the hash function - type Parameter: Copy + Sized + Send + Sync + Serialize + DeserializeOwned; + type Parameter: Copy + Sized + Send + Sync + Serialize + DeserializeOwned + Encode + Decode; /// Tweak type for domain separation type Tweak; /// Domain element type (defines output and input types to the hash) - type Domain: Copy + PartialEq + Sized + Send + Sync + Serialize + DeserializeOwned; + type Domain: Copy + + PartialEq + + Sized + + Send + + Sync + + Serialize + + DeserializeOwned + + Encode + + Decode; /// Generates a random public parameter. fn rand_parameter(rng: &mut R) -> Self::Parameter; diff --git a/src/symmetric/tweak_hash/poseidon.rs b/src/symmetric/tweak_hash/poseidon.rs index a584907..3ae8db4 100644 --- a/src/symmetric/tweak_hash/poseidon.rs +++ b/src/symmetric/tweak_hash/poseidon.rs @@ -3,10 +3,10 @@ use core::array; use p3_field::{Algebra, PackedValue, PrimeCharacteristicRing, PrimeField64}; use p3_symmetric::CryptographicPermutation; use rayon::prelude::*; -use serde::{Serialize, de::DeserializeOwned}; use crate::TWEAK_SEPARATOR_FOR_CHAIN_HASH; use crate::TWEAK_SEPARATOR_FOR_TREE_HASH; +use crate::array::FieldArray; use crate::poseidon2_16; use crate::poseidon2_24; use crate::simd_utils::{pack_array, unpack_array}; @@ -263,22 +263,19 @@ impl< const CAPACITY: usize, const NUM_CHUNKS: usize, > TweakableHash for PoseidonTweakHash -where - [F; PARAMETER_LEN]: Serialize + DeserializeOwned, - [F; HASH_LEN]: Serialize + DeserializeOwned, { - type Parameter = [F; PARAMETER_LEN]; + type Parameter = FieldArray; type Tweak = PoseidonTweak; - type Domain = [F; HASH_LEN]; + type Domain = FieldArray; fn rand_parameter(rng: &mut R) -> Self::Parameter { - rng.random() + FieldArray(rng.random()) } fn rand_domain(rng: &mut R) -> Self::Domain { - rng.random() + FieldArray(rng.random()) } fn tree_tweak(level: u8, pos_in_level: u32) -> Self::Tweak { @@ -318,7 +315,12 @@ where .chain(single.iter()) .copied() .collect(); - poseidon_compress::(&perm, &combined_input) + FieldArray( + poseidon_compress::( + &perm, + &combined_input, + ), + ) } [left, right] => { @@ -331,7 +333,12 @@ where .chain(right.iter()) .copied() .collect(); - poseidon_compress::(&perm, &combined_input) + FieldArray( + poseidon_compress::( + &perm, + &combined_input, + ), + ) } _ if message.len() > 2 => { @@ -340,7 +347,7 @@ where let combined_input: Vec = parameter .iter() .chain(tweak_fe.iter()) - .chain(message.iter().flatten()) + .chain(message.iter().flat_map(|x| x.iter())) .copied() .collect(); @@ -354,13 +361,13 @@ where poseidon_safe_domain_separator::( &perm, &lengths, ); - poseidon_sponge::( + FieldArray(poseidon_sponge::( &perm, &capacity_value, &combined_input, - ) + )) } - _ => [F::ONE; HASH_LEN], // Unreachable case, added for safety + _ => FieldArray([F::ONE; HASH_LEN]), // Unreachable case, added for safety } } @@ -394,7 +401,7 @@ where let width = PackedF::WIDTH; // Allocate output buffer for all leaves. - let mut leaves = vec![[F::ZERO; HASH_LEN]; epochs.len()]; + let mut leaves = vec![FieldArray([F::ZERO; HASH_LEN]); epochs.len()]; // PREPARE PACKED CONSTANTS @@ -442,7 +449,7 @@ where let mut packed_chains: [[PackedF; HASH_LEN]; NUM_CHUNKS] = array::from_fn(|c_idx| { // Generate starting points for this chain across all epochs. - let starts: [[F; HASH_LEN]; PackedF::WIDTH] = array::from_fn(|lane| { + let starts: [_; PackedF::WIDTH] = array::from_fn(|lane| { PRF::get_domain_element(prf_key, epoch_chunk[lane], c_idx as u64).into() }); @@ -543,7 +550,6 @@ where // // Convert from vertical packing back to scalar layout. // Each lane becomes one leaf in the output slice. - unpack_array(&packed_leaves, leaves_chunk); });