diff --git a/src/array.rs b/src/array.rs
index 9314ef1..3328508 100644
--- a/src/array.rs
+++ b/src/array.rs
@@ -3,6 +3,7 @@ use ssz::{Decode, DecodeError, Encode};
 use std::ops::{Deref, DerefMut};
 use crate::F;
+use crate::serialization::Serializable;
 use p3_field::{PrimeCharacteristicRing, PrimeField32, RawDataSerializable};
 /// A wrapper around an array of field elements that implements SSZ Encode/Decode.
@@ -86,6 +87,8 @@ impl<const N: usize> Decode for FieldArray<N> {
     }
 }
+impl<const N: usize> Serializable for FieldArray<N> {}
+
 impl<const N: usize> Serialize for FieldArray<N> {
     fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
     where
diff --git a/src/inc_encoding.rs b/src/inc_encoding.rs
index 3d88336..001e641 100644
--- a/src/inc_encoding.rs
+++ b/src/inc_encoding.rs
@@ -1,8 +1,8 @@
 use rand::Rng;
-use serde::{Serialize, de::DeserializeOwned};
 use std::fmt::Debug;
 use crate::MESSAGE_LENGTH;
+use crate::serialization::Serializable;
 /// Trait to model incomparable encoding schemes.
 /// These schemes allow to encode a message into a codeword.
@@ -17,8 +17,8 @@ use crate::MESSAGE_LENGTH;
 /// x = (x_1,..,x_k) and x' = (x'_1,..,x'_k) we have
 /// x_i > x'_i for all i = 1,...,k.
 pub trait IncomparableEncoding {
-    type Parameter: Serialize + DeserializeOwned;
-    type Randomness: Serialize + DeserializeOwned;
+    type Parameter: Serializable;
+    type Randomness: Serializable;
     type Error: Debug;
     /// number of entries in a codeword
diff --git a/src/lib.rs b/src/lib.rs
index af6ffd8..ce6be6f 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -17,6 +17,7 @@ pub(crate) type PackedF = ::Packing;
 pub(crate) mod array;
 pub(crate) mod hypercube;
 pub(crate) mod inc_encoding;
+pub mod serialization;
 pub mod signature;
 pub(crate) mod simd_utils;
 pub(crate) mod symmetric;
diff --git a/src/serialization.rs b/src/serialization.rs
new file mode 100644
index 0000000..bee25fc
--- /dev/null
+++ b/src/serialization.rs
@@ -0,0 +1,45 @@
+//! A unified serialization implementation
+
+use serde::{Serialize, de::DeserializeOwned};
+use ssz::{Decode, DecodeError, Encode};
+
+/// A supertrait combining all serialization capabilities needed for leanSig types.
+pub trait Serializable: Serialize + DeserializeOwned + Encode + Decode + Sized {
+    /// Converts this object to a canonical byte representation.
+    ///
+    /// # Canonical Format
+    ///
+    /// - All field elements are converted to canonical `u32` form (not Montgomery)
+    /// - All `u32` values are encoded as 4 bytes in little-endian order
+    ///
+    /// # Returns
+    ///
+    /// A `Vec<u8>` containing the canonical byte representation of this object.
+    fn to_bytes(&self) -> Vec<u8> {
+        // TODO: Update this to not use SSZ internally.
+        self.as_ssz_bytes()
+    }
+
+    /// Parses an object from its canonical byte representation.
+    ///
+    /// # Canonical Format
+    ///
+    /// The input bytes must follow the same canonical format as `to_bytes()`:
+    /// - Field elements as canonical `u32` values (4 bytes, little-endian)
+    /// - Composite structures following SSZ layout rules
+    ///
+    /// # Arguments
+    ///
+    /// * `bytes` - The canonical binary data to parse
+    ///
+    /// # Returns
+    ///
+    /// - `Ok(Self)` if the bytes represent a valid object
+    /// - `Err(DecodeError)` if the bytes are malformed or invalid
+    fn from_bytes(bytes: &[u8]) -> Result<Self, DecodeError> {
+        // TODO: Update this to not use SSZ internally.
+ Self::from_ssz_bytes(bytes) + } +} + +impl Serializable for [u8; 32] {} diff --git a/src/signature.rs b/src/signature.rs index 0d04672..0972c93 100644 --- a/src/signature.rs +++ b/src/signature.rs @@ -1,9 +1,8 @@ use std::ops::Range; use crate::MESSAGE_LENGTH; +use crate::serialization::Serializable; use rand::Rng; -use serde::{Serialize, de::DeserializeOwned}; -use ssz::{Decode, Encode}; use thiserror::Error; /// Error enum for the signing process. @@ -99,17 +98,21 @@ pub trait SignatureScheme { /// The key must be serializable to allow for network transmission and storage. /// /// We must support SSZ encoding for Ethereum consensus layer compatibility. - type PublicKey: Serialize + DeserializeOwned + Encode + Decode; + type PublicKey: Serializable; /// The secret key used for signing. /// /// The key must be serializable for persistence and secure backup. - type SecretKey: SignatureSchemeSecretKey + Serialize + DeserializeOwned; + /// + /// We must support SSZ encoding for Ethereum consensus layer compatibility. + type SecretKey: SignatureSchemeSecretKey + Serializable; /// The signature object produced by the signing algorithm. /// /// The signature must be serializable to allow for network transmission and storage. - type Signature: Serialize + DeserializeOwned; + /// + /// We must support SSZ encoding for Ethereum consensus layer compatibility. + type Signature: Serializable; /// The maximum number of epochs supported by this signature scheme configuration, /// denoted as $L$ in the literature [DKKW25a, DKKW25b]. diff --git a/src/signature/generalized_xmss.rs b/src/signature/generalized_xmss.rs index ff2e36f..93b88a4 100644 --- a/src/signature/generalized_xmss.rs +++ b/src/signature/generalized_xmss.rs @@ -7,6 +7,7 @@ use serde::{Deserialize, Serialize}; use crate::{ MESSAGE_LENGTH, inc_encoding::IncomparableEncoding, + serialization::Serializable, signature::SignatureSchemeSecretKey, symmetric::{ prf::Pseudorandom, @@ -47,6 +48,130 @@ pub struct GeneralizedXMSSSignature hashes: Vec, } +impl Encode for GeneralizedXMSSSignature { + fn is_ssz_fixed_len() -> bool { + false + } + + fn ssz_bytes_len(&self) -> usize { + // SSZ Container: offset (4) + rho (fixed) + offset (4) + variable data + let offset_size = 4; + let rho_size = self.rho.ssz_bytes_len(); + let path_size = self.path.ssz_bytes_len(); + let hashes_size = self.hashes.ssz_bytes_len(); + + offset_size + rho_size + offset_size + path_size + hashes_size + } + + fn ssz_append(&self, buf: &mut Vec) { + // Appends the SSZ encoding to the buffer. + // + // SSZ Container encoding with fields interleaved in declaration order: + // - Field 1 (path): variable → write offset + // - Field 2 (rho): fixed → write data + // - Field 3 (hashes): variable → write offset + // + // Then write variable data in order: path, hashes + + // Calculate offsets (start of variable data) + let rho_size = self.rho.ssz_bytes_len(); + // offset + rho + offset + let fixed_size = 4 + rho_size + 4; + + let offset_path = fixed_size; + let offset_hashes = offset_path + self.path.ssz_bytes_len(); + + // 1. Encode offset for first variable field: path + buf.extend_from_slice(&(offset_path as u32).to_le_bytes()); + + // 2. Encode fixed field: rho + self.rho.ssz_append(buf); + + // 3. Encode offset for second variable field: hashes + buf.extend_from_slice(&(offset_hashes as u32).to_le_bytes()); + + // 4. 
Encode variable data in order + self.path.ssz_append(buf); + self.hashes.ssz_append(buf); + } +} + +impl Decode for GeneralizedXMSSSignature { + fn is_ssz_fixed_len() -> bool { + false + } + + fn from_ssz_bytes(bytes: &[u8]) -> Result { + // Decodes a generalized XMSS signature from SSZ bytes. + // + // Fields are interleaved: offset_path → rho → offset_hashes → variable data + + // Get fixed size of rho field + let rho_size = if ::is_ssz_fixed_len() { + ::ssz_fixed_len() + } else { + return Err(DecodeError::BytesInvalid( + "IE::Randomness must be fixed length".into(), + )); + }; + + // Minimum size: offset (4) + rho (fixed) + offset (4) + let min_size = 4 + rho_size + 4; + if bytes.len() < min_size { + return Err(DecodeError::InvalidByteLength { + len: bytes.len(), + expected: min_size, + }); + } + + // 1. Read offset for first variable field: path + let offset_path = u32::from_le_bytes(bytes[0..4].try_into().map_err(|_| { + DecodeError::InvalidByteLength { + len: bytes.len(), + expected: 4, + } + })?) as usize; + + // 2. Decode fixed field: rho + let rho = IE::Randomness::from_ssz_bytes(&bytes[4..4 + rho_size])?; + + // 3. Read offset for second variable field: hashes + let offset_hashes = + u32::from_le_bytes(bytes[4 + rho_size..8 + rho_size].try_into().map_err(|_| { + DecodeError::InvalidByteLength { + len: bytes.len(), + expected: 8 + rho_size, + } + })?) as usize; + + // Validate offset_path points to end of fixed part + let expected_offset_path = 4 + rho_size + 4; + if offset_path != expected_offset_path { + return Err(DecodeError::InvalidByteLength { + len: offset_path, + expected: expected_offset_path, + }); + } + + // Panic safety: Ensure offsets are monotonic and within bounds + // This prevents panic when creating slices below + if offset_path > offset_hashes || offset_hashes > bytes.len() { + return Err(DecodeError::BytesInvalid(format!( + "Invalid variable offsets: path={} hashes={} len={}", + offset_path, + offset_hashes, + bytes.len() + ))); + } + + // 4. Decode variable fields (now safe after bounds check) + let path = HashTreeOpening::::from_ssz_bytes(&bytes[offset_path..offset_hashes])?; + let hashes = Vec::::from_ssz_bytes(&bytes[offset_hashes..])?; + + Ok(Self { path, rho, hashes }) + } +} + /// Public key for GeneralizedXMSSSignatureScheme /// It contains a Merkle root and a parameter for the tweakable hash #[derive(Serialize, Deserialize)] @@ -70,15 +195,275 @@ pub struct GeneralizedXMSSSecretKey< > { prf_key: PRF::Key, parameter: TH::Parameter, - activation_epoch: usize, - num_active_epochs: usize, + activation_epoch: u64, + num_active_epochs: u64, top_tree: HashSubTree, - left_bottom_tree_index: usize, + left_bottom_tree_index: u64, left_bottom_tree: HashSubTree, right_bottom_tree: HashSubTree, _encoding_type: PhantomData, } +impl + Encode for GeneralizedXMSSSecretKey +{ + fn is_ssz_fixed_len() -> bool { + // It has variable length due to HashSubTree field + false + } + + fn ssz_bytes_len(&self) -> usize { + // Computes the SSZ encoded length. 
+ // Format: Fields interleaved in declaration order with offsets for variable fields + + // Fixed-length fields (using u64 for platform independence) + let prf_key_size = self.prf_key.ssz_bytes_len(); + let parameter_size = self.parameter.ssz_bytes_len(); + let activation_epoch_size = 8; // u64 + let num_active_epochs_size = 8; // u64 + + // Variable fields need 4-byte offsets each + let offset_size = 4; + let top_tree_size = self.top_tree.ssz_bytes_len(); + + let left_bottom_tree_index_size = 8; // u64 + let left_bottom_tree_size = self.left_bottom_tree.ssz_bytes_len(); + let right_bottom_tree_size = self.right_bottom_tree.ssz_bytes_len(); + + prf_key_size + + parameter_size + + activation_epoch_size + + num_active_epochs_size + + offset_size // top_tree offset + + left_bottom_tree_index_size + + offset_size // left_bottom_tree offset + + offset_size // right_bottom_tree offset + + top_tree_size + + left_bottom_tree_size + + right_bottom_tree_size + } + + fn ssz_append(&self, buf: &mut Vec) { + // Appends the SSZ encoding to the buffer. + // + // SSZ Container encoding with fields interleaved in declaration order: + // - Field 1 (prf_key): fixed → write data + // - Field 2 (parameter): fixed → write data + // - Field 3 (activation_epoch): fixed → write data + // - Field 4 (num_active_epochs): fixed → write data + // - Field 5 (top_tree): variable → write offset + // - Field 6 (left_bottom_tree_index): fixed → write data + // - Field 7 (left_bottom_tree): variable → write offset + // - Field 8 (right_bottom_tree): variable → write offset + // + // Then write variable data in order: top_tree, left_bottom_tree, right_bottom_tree + + // Calculate sizes of fixed fields + let prf_key_size = self.prf_key.ssz_bytes_len(); + let parameter_size = self.parameter.ssz_bytes_len(); + + // Calculate start of variable data + let fixed_size = prf_key_size + parameter_size + 8 + 8 + 4 + 8 + 4 + 4; + + let offset_top_tree = fixed_size; + let offset_left_bottom = offset_top_tree + self.top_tree.ssz_bytes_len(); + let offset_right_bottom = offset_left_bottom + self.left_bottom_tree.ssz_bytes_len(); + + // 1. Encode fixed field: prf_key + self.prf_key.ssz_append(buf); + + // 2. Encode fixed field: parameter + self.parameter.ssz_append(buf); + + // 3. Encode fixed field: activation_epoch (u64) + buf.extend_from_slice(&self.activation_epoch.to_le_bytes()); + + // 4. Encode fixed field: num_active_epochs (u64) + buf.extend_from_slice(&self.num_active_epochs.to_le_bytes()); + + // 5. Encode offset for first variable field: top_tree + buf.extend_from_slice(&(offset_top_tree as u32).to_le_bytes()); + + // 6. Encode fixed field: left_bottom_tree_index (u64) + buf.extend_from_slice(&self.left_bottom_tree_index.to_le_bytes()); + + // 7. Encode offset for second variable field: left_bottom_tree + buf.extend_from_slice(&(offset_left_bottom as u32).to_le_bytes()); + + // 8. Encode offset for third variable field: right_bottom_tree + buf.extend_from_slice(&(offset_right_bottom as u32).to_le_bytes()); + + // 9. Encode variable data in order + self.top_tree.ssz_append(buf); + self.left_bottom_tree.ssz_append(buf); + self.right_bottom_tree.ssz_append(buf); + } +} + +impl + Decode for GeneralizedXMSSSecretKey +{ + fn is_ssz_fixed_len() -> bool { + false + } + + #[allow(clippy::too_many_lines)] + fn from_ssz_bytes(bytes: &[u8]) -> Result { + // Decodes a generalized XMSS secret key from SSZ bytes. 
+ // + // Fields are interleaved: + // - prf_key + // - parameter + // - activation_epoch + // - num_active_epochs + // - offset_top_tree + // - left_bottom_tree_index + // - offset_left_bottom + // - offset_right_bottom + // - variable data + + // Get fixed sizes for prf_key and parameter + let prf_key_size = if ::is_ssz_fixed_len() { + ::ssz_fixed_len() + } else { + return Err(DecodeError::BytesInvalid( + "PRF::Key must be fixed length".into(), + )); + }; + + let parameter_size = if ::is_ssz_fixed_len() { + ::ssz_fixed_len() + } else { + return Err(DecodeError::BytesInvalid( + "TH::Parameter must be fixed length".into(), + )); + }; + + // Minimum size: prf_key + parameter + 3×u64 (24) + 3×offset (12) + let min_fixed_size = prf_key_size + parameter_size + 24 + 12; + if bytes.len() < min_fixed_size { + return Err(DecodeError::InvalidByteLength { + len: bytes.len(), + expected: min_fixed_size, + }); + } + + // Track current position + let mut pos = 0; + + // 1. Decode fixed field: prf_key + let prf_key = PRF::Key::from_ssz_bytes(&bytes[pos..pos + prf_key_size])?; + pos += prf_key_size; + + // 2. Decode fixed field: parameter + let parameter = TH::Parameter::from_ssz_bytes(&bytes[pos..pos + parameter_size])?; + pos += parameter_size; + + // 3. Decode fixed field: activation_epoch (u64) + let activation_epoch = + u64::from_le_bytes(bytes[pos..pos + 8].try_into().map_err(|_| { + DecodeError::InvalidByteLength { + len: bytes.len(), + expected: pos + 8, + } + })?); + pos += 8; + + // 4. Decode fixed field: num_active_epochs (u64) + let num_active_epochs = + u64::from_le_bytes(bytes[pos..pos + 8].try_into().map_err(|_| { + DecodeError::InvalidByteLength { + len: bytes.len(), + expected: pos + 8, + } + })?); + pos += 8; + + // 5. Read offset for first variable field: top_tree + let offset_top_tree = u32::from_le_bytes(bytes[pos..pos + 4].try_into().map_err(|_| { + DecodeError::InvalidByteLength { + len: bytes.len(), + expected: pos + 4, + } + })?) as usize; + pos += 4; + + // 6. Decode fixed field: left_bottom_tree_index (u64) + let left_bottom_tree_index = + u64::from_le_bytes(bytes[pos..pos + 8].try_into().map_err(|_| { + DecodeError::InvalidByteLength { + len: bytes.len(), + expected: pos + 8, + } + })?); + pos += 8; + + // 7. Read offset for second variable field: left_bottom_tree + let offset_left_bottom = + u32::from_le_bytes(bytes[pos..pos + 4].try_into().map_err(|_| { + DecodeError::InvalidByteLength { + len: bytes.len(), + expected: pos + 4, + } + })?) as usize; + pos += 4; + + // 8. Read offset for third variable field: right_bottom_tree + let offset_right_bottom = + u32::from_le_bytes(bytes[pos..pos + 4].try_into().map_err(|_| { + DecodeError::InvalidByteLength { + len: bytes.len(), + expected: pos + 4, + } + })?) as usize; + pos += 4; + + // Validate that fixed part ends at first offset + if pos != offset_top_tree { + return Err(DecodeError::InvalidByteLength { + len: pos, + expected: offset_top_tree, + }); + } + + // Panic safety: Ensure offsets are monotonic and within bounds + // + // This prevents panic when creating slices below + // Verify: offset_top <= offset_left <= offset_right <= bytes.len() + if offset_top_tree > offset_left_bottom + || offset_left_bottom > offset_right_bottom + || offset_right_bottom > bytes.len() + { + return Err(DecodeError::BytesInvalid(format!( + "Invalid variable offsets: top={} left={} right={} len={}", + offset_top_tree, + offset_left_bottom, + offset_right_bottom, + bytes.len() + ))); + } + + // 9. 
Decode variable fields (now safe after bounds check) + let top_tree = + HashSubTree::::from_ssz_bytes(&bytes[offset_top_tree..offset_left_bottom])?; + let left_bottom_tree = + HashSubTree::::from_ssz_bytes(&bytes[offset_left_bottom..offset_right_bottom])?; + let right_bottom_tree = HashSubTree::::from_ssz_bytes(&bytes[offset_right_bottom..])?; + + Ok(Self { + prf_key, + parameter, + activation_epoch, + num_active_epochs, + top_tree, + left_bottom_tree_index, + left_bottom_tree, + right_bottom_tree, + _encoding_type: PhantomData, + }) + } +} + impl SignatureSchemeSecretKey for GeneralizedXMSSSecretKey where @@ -87,26 +472,26 @@ where TH::Parameter: Into, { fn get_activation_interval(&self) -> std::ops::Range { - let start = self.activation_epoch as u64; - let end = start + self.num_active_epochs as u64; + let start = self.activation_epoch; + let end = start + self.num_active_epochs; start..end } fn get_prepared_interval(&self) -> std::ops::Range { // the key is prepared for all epochs covered by the left and right bottom tree // and each bottom tree covers exactly 2^{LOG_LIFETIME / 2} leafs - let leafs_per_bottom_tree = 1 << (LOG_LIFETIME / 2); - let start = (self.left_bottom_tree_index * leafs_per_bottom_tree) as u64; - let end = start + (2 * leafs_per_bottom_tree as u64); + let leafs_per_bottom_tree = 1u64 << (LOG_LIFETIME / 2); + let start = self.left_bottom_tree_index * leafs_per_bottom_tree; + let end = start + (2 * leafs_per_bottom_tree); start..end } fn advance_preparation(&mut self) { // First, check if advancing is possible by comparing to activation interval. - let leafs_per_bottom_tree = 1 << (LOG_LIFETIME / 2); + let leafs_per_bottom_tree = 1u64 << (LOG_LIFETIME / 2); let next_prepared_end_epoch = self.left_bottom_tree_index * leafs_per_bottom_tree + 3 * leafs_per_bottom_tree; - if next_prepared_end_epoch as u64 > self.get_activation_interval().end { + if next_prepared_end_epoch > self.get_activation_interval().end { return; } @@ -194,7 +579,7 @@ fn bottom_tree_from_prf_key< const LOG_LIFETIME: usize, >( prf_key: &PRF::Key, - bottom_tree_index: usize, + bottom_tree_index: u64, parameter: &TH::Parameter, ) -> HashSubTree where @@ -202,7 +587,7 @@ where PRF::Randomness: Into, TH::Parameter: Into, { - let leafs_per_bottom_tree = 1 << (LOG_LIFETIME / 2); + let leafs_per_bottom_tree = 1u64 << (LOG_LIFETIME / 2); let num_chains = IE::DIMENSION; let chain_length = IE::BASE; @@ -219,7 +604,7 @@ where // now that we have the hashes of all chain ends (= leafs of our tree), we can compute the bottom tree HashSubTree::new_bottom_tree( LOG_LIFETIME, - bottom_tree_index, + bottom_tree_index as usize, parameter, chain_ends_hashes, ) @@ -295,7 +680,7 @@ where // leafs of our bottom trees. This is done in `bottom_tree_from_prf_key`. 
let mut roots_of_bottom_trees = Vec::with_capacity(num_bottom_trees); - let left_bottom_tree_index = start_bottom_tree_index; + let left_bottom_tree_index = start_bottom_tree_index as u64; let left_bottom_tree = bottom_tree_from_prf_key::( &prf_key, left_bottom_tree_index, @@ -303,7 +688,7 @@ where ); roots_of_bottom_trees.push(left_bottom_tree.root()); - let right_bottom_tree_index = start_bottom_tree_index + 1; + let right_bottom_tree_index = (start_bottom_tree_index + 1) as u64; let right_bottom_tree = bottom_tree_from_prf_key::( &prf_key, right_bottom_tree_index, @@ -318,7 +703,7 @@ where .map(|bottom_tree_index| { let bottom_tree = bottom_tree_from_prf_key::( &prf_key, - bottom_tree_index, + bottom_tree_index as u64, ¶meter, ); bottom_tree.root() @@ -342,8 +727,8 @@ where let sk = GeneralizedXMSSSecretKey { prf_key, parameter, - activation_epoch, - num_active_epochs, + activation_epoch: activation_epoch as u64, + num_active_epochs: num_active_epochs as u64, top_tree, left_bottom_tree_index, left_bottom_tree, @@ -375,7 +760,7 @@ where // first component of the signature is the Merkle path that // opens the one-time pk for that epoch, where the one-time pk // will be recomputed by the verifier from the signature. - let leafs_per_bottom_tree = 1 << (LOG_LIFETIME / 2); + let leafs_per_bottom_tree = 1u64 << (LOG_LIFETIME / 2); let boundary_between_bottom_trees = (sk.left_bottom_tree_index * leafs_per_bottom_tree + leafs_per_bottom_tree) as u32; let bottom_tree = if epoch < boundary_between_bottom_trees { @@ -574,6 +959,18 @@ impl Decode for GeneralizedXMSSPublicKey { } } +impl Serializable for GeneralizedXMSSPublicKey {} + +impl Serializable + for GeneralizedXMSSSignature +{ +} + +impl + Serializable for GeneralizedXMSSSecretKey +{ +} + /// Instantiations of the generalized XMSS signature scheme based on Poseidon2 pub mod instantiations_poseidon; /// Instantiations of the generalized XMSS signature scheme based on the @@ -583,7 +980,6 @@ pub mod instantiations_poseidon_top_level; #[cfg(test)] mod tests { use crate::{ - array::FieldArray, inc_encoding::target_sum::TargetSumEncoding, signature::test_templates::test_signature_scheme_correctness, symmetric::{ @@ -596,7 +992,7 @@ mod tests { use super::*; use crate::{F, symmetric::tweak_hash::poseidon::PoseidonTweakHash}; - use p3_field::PrimeCharacteristicRing; + use p3_field::RawDataSerializable; use rand::rng; use ssz::{Decode, Encode}; @@ -738,82 +1134,103 @@ mod tests { } #[test] - fn test_public_key_ssz_roundtrip() { - let mut rng = rng(); - let root = TestTH::rand_domain(&mut rng); - let parameter = TestTH::rand_parameter(&mut rng); - - let public_key = GeneralizedXMSSPublicKey:: { root, parameter }; - - // Encode to SSZ - let encoded = public_key.as_ssz_bytes(); - - // Check expected size: (7 + 5) * 4 = 48 bytes - assert_eq!(encoded.len(), 48); - - // Decode from SSZ - let decoded = - GeneralizedXMSSPublicKey::::from_ssz_bytes(&encoded).expect("Decoding failed"); - - // Check fields match - assert_eq!(public_key.root, decoded.root); - assert_eq!(public_key.parameter, decoded.parameter); - } + fn test_ssz_encoding_structure() { + type PRF = ShakePRFtoF<7, 5>; + type TH = PoseidonTweakW1L5; + type MH = PoseidonMessageHashW1; + const BASE: usize = MH::BASE; + const NUM_CHUNKS: usize = MH::DIMENSION; + const MAX_CHUNK_VALUE: usize = BASE - 1; + const EXPECTED_SUM: usize = NUM_CHUNKS * MAX_CHUNK_VALUE / 2; + type IE = TargetSumEncoding; + const LOG_LIFETIME: usize = 6; + type Sig = GeneralizedXMSSSignatureScheme; - #[test] - fn 
test_public_key_ssz_deterministic() { let mut rng = rng(); + + // Test PublicKey encoding structure let root = TestTH::rand_domain(&mut rng); let parameter = TestTH::rand_parameter(&mut rng); - - let public_key = GeneralizedXMSSPublicKey:: { root, parameter }; - - // Encode multiple times - let encoded1 = public_key.as_ssz_bytes(); - let encoded2 = public_key.as_ssz_bytes(); - - // Should be identical - assert_eq!(encoded1, encoded2); - } - - #[test] - fn test_public_key_ssz_zero_values() { - let root = FieldArray([F::ZERO; 7]); - let parameter = FieldArray([F::ZERO; 5]); - let public_key = GeneralizedXMSSPublicKey:: { root, parameter }; - + // Serialize to bytes let encoded = public_key.as_ssz_bytes(); - let decoded = - GeneralizedXMSSPublicKey::::from_ssz_bytes(&encoded).expect("Decoding failed"); - + // Verify expected size based on field element counts + assert_eq!(encoded.len(), (7 + 5) * F::NUM_BYTES); + // Verify first field element is encoded correctly + let first_fe_bytes = root.as_ssz_bytes(); + assert_eq!(&encoded[0..F::NUM_BYTES], &first_fe_bytes[0..F::NUM_BYTES]); + // Decode and verify roundtrip + let decoded = GeneralizedXMSSPublicKey::::from_ssz_bytes(&encoded).unwrap(); assert_eq!(public_key.root, decoded.root); assert_eq!(public_key.parameter, decoded.parameter); - } - #[test] - fn test_public_key_ssz_max_values() { - use p3_field::PrimeField32; - - let max_val = F::ORDER_U32 - 1; - let root = FieldArray([F::new(max_val); 7]); - let parameter = FieldArray([F::new(max_val); 5]); - - let public_key = GeneralizedXMSSPublicKey:: { root, parameter }; - - let encoded = public_key.as_ssz_bytes(); - let decoded = - GeneralizedXMSSPublicKey::::from_ssz_bytes(&encoded).expect("Decoding failed"); - - assert_eq!(public_key.root, decoded.root); - assert_eq!(public_key.parameter, decoded.parameter); + // Test Signature encoding structure + let (pk, sk) = Sig::key_gen(&mut rng, 0, 1 << LOG_LIFETIME); + let message = rng.random(); + let epoch = 5; + // Generate valid signature + let signature = Sig::sign(&sk, epoch, &message).unwrap(); + // Serialize to bytes + let sig_encoded = signature.as_ssz_bytes(); + // Calculate randomness size + let rho_size = signature.rho.ssz_bytes_len(); + // Verify minimum size includes two offsets plus fixed field + assert!(sig_encoded.len() >= 4 + rho_size + 4); + // Read first offset value from bytes 0-4 + let offset_path = u32::from_le_bytes(sig_encoded[0..4].try_into().unwrap()) as usize; + // Verify first offset points to end of fixed part + assert_eq!(offset_path, 4 + rho_size + 4); + // Decode and verify signature still validates + let sig_decoded = + ::Signature::from_ssz_bytes(&sig_encoded).unwrap(); + assert!(Sig::verify(&pk, epoch, &message, &sig_decoded)); + + // Test SecretKey encoding structure + let (_pk2, sk2) = Sig::key_gen(&mut rng, 0, 8); + // Serialize secret key to bytes + let sk_encoded = sk2.as_ssz_bytes(); + // Calculate fixed field sizes + let prf_key_size = sk2.prf_key.ssz_bytes_len(); + let param_size = sk2.parameter.ssz_bytes_len(); + let fixed_part_size = prf_key_size + param_size + 8 + 8 + 4 + 8 + 4 + 4; + // Verify minimum size includes all fixed fields + assert!(sk_encoded.len() >= fixed_part_size); + // Read activation epoch value from fixed position + let activation_start = prf_key_size + param_size; + let activation_epoch = u64::from_le_bytes( + sk_encoded[activation_start..activation_start + 8] + .try_into() + .unwrap(), + ); + // Verify stored value matches original + assert_eq!(activation_epoch, sk2.activation_epoch); + 
// Decode and verify roundtrip by re-encoding + let sk_decoded = ::SecretKey::from_ssz_bytes(&sk_encoded).unwrap(); + let sk_reencoded = sk_decoded.as_ssz_bytes(); + assert_eq!(sk_encoded, sk_reencoded); } #[test] - fn test_public_key_ssz_invalid_length_too_short() { - let bytes = vec![0u8; 47]; // Should be 48 bytes - let result = GeneralizedXMSSPublicKey::::from_ssz_bytes(&bytes); - assert!(result.is_err()); + fn test_ssz_decoding_errors() { + type PRF = ShakePRFtoF<7, 5>; + type TH = PoseidonTweakW1L5; + type MH = PoseidonMessageHashW1; + const BASE: usize = MH::BASE; + const NUM_CHUNKS: usize = MH::DIMENSION; + const MAX_CHUNK_VALUE: usize = BASE - 1; + const EXPECTED_SUM: usize = NUM_CHUNKS * MAX_CHUNK_VALUE / 2; + type IE = TargetSumEncoding; + const LOG_LIFETIME: usize = 6; + type Sig = GeneralizedXMSSSignatureScheme; + + // PublicKey: buffer too small + // TestTH = PoseidonTweakW1L5 has FieldArray<7> hash and FieldArray<5> domain + // Total size: (7 + 5) * F::NUM_BYTES = 12 * 4 = 48 bytes + // Create buffer with only 47 bytes (one byte short) + let encoded = vec![0u8; 47]; + // Attempt decode with insufficient bytes + let result = GeneralizedXMSSPublicKey::::from_ssz_bytes(&encoded); + // Decoder reports actual buffer size (47) vs expected (48) assert!(matches!( result, Err(DecodeError::InvalidByteLength { @@ -821,60 +1238,302 @@ mod tests { expected: 48 }) )); - } - #[test] - fn test_public_key_ssz_invalid_length_too_long() { - let bytes = vec![0u8; 49]; // Should be 48 bytes - let result = GeneralizedXMSSPublicKey::::from_ssz_bytes(&bytes); - assert!(result.is_err()); + // Signature: buffer too small - only 8 bytes when we need more + // IE::Randomness = MH::Randomness = FieldArray<5> (from PoseidonMessageHashW1) + // FieldArray<5> has ssz_fixed_len() = 5 * F::NUM_BYTES = 5 * 4 = 20 bytes + // Minimum size: offset (4) + rho (20) + offset (4) = 28 bytes + let encoded = vec![0u8; 8]; + let result = ::Signature::from_ssz_bytes(&encoded); + // Decoder checks min_size at line 119: reports actual (8) vs expected (28) assert!(matches!( result, Err(DecodeError::InvalidByteLength { - len: 49, - expected: 48 + len: 8, + expected: 28 + }) + )); + + // Signature: invalid offset value pointing to wrong location + // Create buffer with sufficient space (28 + 100 bytes) + let mut encoded = vec![0u8; 128]; + // Write incorrect offset (99) that doesn't match expected first offset (28) + encoded[0..4].copy_from_slice(&99u32.to_le_bytes()); + // Write valid rho data at bytes 4..24 (20 bytes of zeros is valid FieldArray<5>) + for i in 0..20 { + encoded[4 + i] = 0; + } + // Write second offset at position 24..28 (actual value doesn't matter) + encoded[24..28].copy_from_slice(&78u32.to_le_bytes()); + // Attempt decode with invalid first offset + let result = ::Signature::from_ssz_bytes(&encoded); + // Decoder at line 149 checks: offset_path (99) != expected_offset_path (28) + // Expected offset points to byte immediately after fixed part: 4 + 20 + 4 = 28 + assert!(matches!( + result, + Err(DecodeError::InvalidByteLength { + len: 99, + expected: 28 }) )); } #[test] - fn test_public_key_ssz_fixed_len_trait() { - assert!( as Encode>::is_ssz_fixed_len()); - assert_eq!( - as Encode>::ssz_fixed_len(), - 48 - ); + #[allow(clippy::items_after_statements)] + fn test_ssz_panic_safety_malicious_offsets() { + type PRF = ShakePRFtoF<7, 5>; + type TH = PoseidonTweakW1L5; + type MH = PoseidonMessageHashW1; + const BASE: usize = MH::BASE; + const NUM_CHUNKS: usize = MH::DIMENSION; + const MAX_CHUNK_VALUE: usize = 
BASE - 1; + const EXPECTED_SUM: usize = NUM_CHUNKS * MAX_CHUNK_VALUE / 2; + type IE = TargetSumEncoding; + const LOG_LIFETIME: usize = 6; + type Sig = GeneralizedXMSSSignatureScheme; + + // Helper: Dynamic Size Calculation + // + // We calculate sizes dynamically to avoid hardcoded mismatch errors. + let mut rng = rand::rng(); + + // Generate dummy objects to measure their SSZ encoded length + let dummy_prf_key = PRF::key_gen(&mut rng); + let dummy_param = TH::rand_parameter(&mut rng); + + let prf_key_size = dummy_prf_key.ssz_bytes_len(); + let param_size = dummy_param.ssz_bytes_len(); + let u64_size = 8; + let offset_size = 4; + + // Calculate the exact size of the "Fixed Part" of the SecretKey container. + // + // Layout: [PRF] [Param] [ActEpoch] [NumActive] [OffTop] [LeftIdx] [OffLeft] [OffRight] + let fixed_part_len = prf_key_size + + param_size + + u64_size // activation_epoch + + u64_size // num_active_epochs + + offset_size // offset_top_tree + + u64_size // left_bottom_tree_index + + offset_size // offset_left_bottom + + offset_size; // offset_right_bottom + + // Helper: Error Verifier + fn assert_bytes_invalid(result: Result, expected_msg_part: &str) { + match result { + Err(DecodeError::BytesInvalid(msg)) => { + assert!( + msg.contains(expected_msg_part), + "Error message '{}' did not contain expected part '{}'", + msg, + expected_msg_part + ); + } + Err(e) => panic!("Wrong error type. Expected BytesInvalid, got {:?}", e), + Ok(_) => panic!("Should have failed with BytesInvalid, but succeeded"), + } + } + + // SCENARIO 1: Signature with Reversed Offsets (Non-Monotonic) + // + // - Structure: GeneralizedXMSSSignature { path, rho, hashes } + // - SSZ Layout: [Offset Path (4)] | [Rho (Var)] | [Offset Hashes (4)] | ... + // - Malicious Input: offset_hashes < offset_path + { + let dummy_rho = IE::rand(&mut rng); + let rho_size = dummy_rho.ssz_bytes_len(); + + // Fixed part = Offset(4) + Rho + Offset(4) + let sig_fixed_part_size = 4 + rho_size + 4; + let mut encoded = vec![0u8; 200]; // Sufficient buffer + + // 1. Write [Offset Path] -> Correctly points to end of fixed part + encoded[0..4].copy_from_slice(&(sig_fixed_part_size as u32).to_le_bytes()); + + // 2. Write [Rho] -> Write valid dummy data + let mut rho_buf = Vec::new(); + dummy_rho.ssz_append(&mut rho_buf); + encoded[4..4 + rho_size].copy_from_slice(&rho_buf); + + // 3. Write [Offset Hashes] -> MALICIOUS! + // We set it to 10, which is less than `offset_path` (sig_fixed_part_size). + // This implies the `path` field has negative length, which causes panic if unchecked. + let offset_hashes_pos = 4 + rho_size; + encoded[offset_hashes_pos..offset_hashes_pos + 4].copy_from_slice(&10u32.to_le_bytes()); + + let result = ::Signature::from_ssz_bytes(&encoded); + assert_bytes_invalid(result, "Invalid variable offsets"); + } + + // SCENARIO 2: Signature with Offset Out of Bounds + // + // Malicious Input: offset_hashes points outside the buffer + { + let dummy_rho = IE::rand(&mut rng); + let rho_size = dummy_rho.ssz_bytes_len(); + let sig_fixed_part_size = 4 + rho_size + 4; + + let mut encoded = vec![0u8; 100]; // Buffer length is 100 + + // 1. Write [Offset Path] -> Correct + encoded[0..4].copy_from_slice(&(sig_fixed_part_size as u32).to_le_bytes()); + + // 2. Write [Rho] -> Correct + let mut rho_buf = Vec::new(); + dummy_rho.ssz_append(&mut rho_buf); + encoded[4..4 + rho_size].copy_from_slice(&rho_buf); + + // 3. Write [Offset Hashes] -> MALICIOUS! + // Set to 200, which is > encoded.len() (100). 
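+            // The decoder's bounds check (offset_hashes > bytes.len()) should reject this
+            // before any slicing and report the buffer length ("len=100") in the error,
+            // which the assertion below checks.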
+ let offset_hashes_pos = 4 + rho_size; + encoded[offset_hashes_pos..offset_hashes_pos + 4] + .copy_from_slice(&200u32.to_le_bytes()); + + let result = ::Signature::from_ssz_bytes(&encoded); + assert_bytes_invalid(result, "len=100"); + } + + // SCENARIO 3: Secret Key with Interleaved Offset Violation + // + // Structure: Fixed Fields interleaved with 3 Variable Offsets (top, left, right) + // Malicious Input: offset_left < offset_top (Reversed variable sections) + { + let mut encoded = vec![0u8; fixed_part_len + 100]; + let mut pos = 0; + + // 1. Write Fixed Fields: PRF Key + // We write actual valid PRF key bytes + let mut prf_buf = Vec::new(); + dummy_prf_key.ssz_append(&mut prf_buf); + encoded[pos..pos + prf_key_size].copy_from_slice(&prf_buf); + pos += prf_key_size; + + // 2. Write Fixed Fields: Parameter + let mut param_buf = Vec::new(); + dummy_param.ssz_append(&mut param_buf); + encoded[pos..pos + param_size].copy_from_slice(¶m_buf); + pos += param_size; + + // 3. Write Fixed Fields: Activation Epoch (u64) + pos += 8; + + // 4. Write Fixed Fields: Num Active Epochs (u64) + pos += 8; + + // 5. Write [Offset Top Tree] + // Should point to the end of the fixed part. + encoded[pos..pos + 4].copy_from_slice(&(fixed_part_len as u32).to_le_bytes()); + pos += 4; + + // 6. Write Fixed Fields: Left Bottom Tree Index (u64) + pos += 8; + + // 7. Write [Offset Left Bottom Tree] -> MALICIOUS! + // We set it to 10. + // Since 10 < fixed_part_len, this offset comes *before* the Top Tree offset. + // This would cause `bytes[offset_top..offset_left]` to panic. + encoded[pos..pos + 4].copy_from_slice(&10u32.to_le_bytes()); + pos += 4; + + // 8. Write [Offset Right Bottom Tree] + // Set to valid relative location to ensure we don't fail on the third offset check first. 
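+            // With offset_left (10) < offset_top, the decoder's monotonicity check should
+            // fire and return `BytesInvalid` instead of slicing out of order.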
+ encoded[pos..pos + 4].copy_from_slice(&((fixed_part_len + 50) as u32).to_le_bytes()); + + let result = ::SecretKey::from_ssz_bytes(&encoded); + assert_bytes_invalid(result, "Invalid variable offsets"); + } } #[test] - fn test_public_key_ssz_specific_values() { - // Test with specific known values to verify byte ordering - let root = FieldArray([ - F::new(1), - F::new(2), - F::new(3), - F::new(4), - F::new(5), - F::new(6), - F::new(7), - ]); - let parameter = FieldArray([F::new(10), F::new(20), F::new(30), F::new(40), F::new(50)]); + fn test_ssz_determinism() { + type PRF = ShakePRFtoF<7, 5>; + type TH = PoseidonTweakW1L5; + type MH = PoseidonMessageHashW1; + const BASE: usize = MH::BASE; + const NUM_CHUNKS: usize = MH::DIMENSION; + const MAX_CHUNK_VALUE: usize = BASE - 1; + const EXPECTED_SUM: usize = NUM_CHUNKS * MAX_CHUNK_VALUE / 2; + type IE = TargetSumEncoding; + const LOG_LIFETIME: usize = 6; + type Sig = GeneralizedXMSSSignatureScheme; + + let mut rng = rng(); + // PublicKey: encode same structure twice + let root = TestTH::rand_domain(&mut rng); + let parameter = TestTH::rand_parameter(&mut rng); let public_key = GeneralizedXMSSPublicKey:: { root, parameter }; + // Serialize twice to verify deterministic output + let encoded1 = public_key.as_ssz_bytes(); + let encoded2 = public_key.as_ssz_bytes(); + // Verify byte-for-byte identical encoding + assert_eq!(encoded1, encoded2); - let encoded = public_key.as_ssz_bytes(); + // Signature: encode same structure twice + let (_pk, sk) = Sig::key_gen(&mut rng, 0, 1 << LOG_LIFETIME); + let message = rng.random(); + let epoch = 5; + let signature = Sig::sign(&sk, epoch, &message).unwrap(); + // Serialize twice to verify deterministic output + let sig_encoded1 = signature.as_ssz_bytes(); + let sig_encoded2 = signature.as_ssz_bytes(); + // Verify byte-for-byte identical encoding + assert_eq!(sig_encoded1, sig_encoded2); + + // SecretKey: encode same structure twice + let (_pk2, sk2) = Sig::key_gen(&mut rng, 0, 8); + // Serialize twice to verify deterministic output + let sk_encoded1 = sk2.as_ssz_bytes(); + let sk_encoded2 = sk2.as_ssz_bytes(); + // Verify byte-for-byte identical encoding + assert_eq!(sk_encoded1, sk_encoded2); + } - // Check first few bytes (little-endian encoding of 1) - assert_eq!(&encoded[0..4], &[1, 0, 0, 0]); - // Check encoding of 2 - assert_eq!(&encoded[4..8], &[2, 0, 0, 0]); - // Check encoding of 10 (first parameter value) - assert_eq!(&encoded[28..32], &[10, 0, 0, 0]); + #[test] + fn test_ssz_signature_integration() { + type PRF = ShakePRFtoF<7, 5>; + type TH = PoseidonTweakW1L5; + type MH = PoseidonMessageHashW1; + const BASE: usize = MH::BASE; + const NUM_CHUNKS: usize = MH::DIMENSION; + const MAX_CHUNK_VALUE: usize = BASE - 1; + const EXPECTED_SUM: usize = NUM_CHUNKS * MAX_CHUNK_VALUE / 2; + type IE = TargetSumEncoding; + const LOG_LIFETIME: usize = 6; + type Sig = GeneralizedXMSSSignatureScheme; - let decoded = - GeneralizedXMSSPublicKey::::from_ssz_bytes(&encoded).expect("Decoding failed"); + let mut rng = rng(); - assert_eq!(public_key.root, decoded.root); - assert_eq!(public_key.parameter, decoded.parameter); + // Generate keypair and sign message + let (pk, sk) = Sig::key_gen(&mut rng, 0, 1 << LOG_LIFETIME); + let message = rng.random(); + let epoch = 7; + // Create valid signature + let signature = Sig::sign(&sk, epoch, &message).unwrap(); + // Verify signature is valid before serialization + assert!(Sig::verify(&pk, epoch, &message, &signature)); + + // Test PublicKey serialization + let pk_encoded = 
pk.as_ssz_bytes(); + let pk_decoded = GeneralizedXMSSPublicKey::::from_ssz_bytes(&pk_encoded).unwrap(); + // Verify decoded key can still verify signature + assert!(Sig::verify(&pk_decoded, epoch, &message, &signature)); + + // Test Signature serialization + let sig_encoded = signature.as_ssz_bytes(); + let sig_decoded = + ::Signature::from_ssz_bytes(&sig_encoded).unwrap(); + // Verify decoded signature still validates with original key + assert!(Sig::verify(&pk, epoch, &message, &sig_decoded)); + // Verify decoded signature validates with decoded key + assert!(Sig::verify(&pk_decoded, epoch, &message, &sig_decoded)); + + // Test SecretKey serialization + let sk_encoded = sk.as_ssz_bytes(); + let sk_decoded = ::SecretKey::from_ssz_bytes(&sk_encoded).unwrap(); + // Sign with decoded key + let sig2 = Sig::sign(&sk_decoded, epoch + 1, &message).unwrap(); + // Verify signature from decoded key validates + assert!(Sig::verify(&pk, epoch + 1, &message, &sig2)); } } diff --git a/src/symmetric/message_hash.rs b/src/symmetric/message_hash.rs index ab2f766..0b78edd 100644 --- a/src/symmetric/message_hash.rs +++ b/src/symmetric/message_hash.rs @@ -1,7 +1,7 @@ use rand::Rng; -use serde::{Serialize, de::DeserializeOwned}; use crate::MESSAGE_LENGTH; +use crate::serialization::Serializable; /// Trait to model a hash function used for message hashing. /// @@ -12,8 +12,8 @@ use crate::MESSAGE_LENGTH; /// /// Note that BASE must be at most 2^8, as we encode chunks as u8. pub trait MessageHash { - type Parameter: Clone + Sized + Serialize + DeserializeOwned; - type Randomness: Serialize + DeserializeOwned; + type Parameter: Clone + Serializable; + type Randomness: Serializable; /// number of entries in a hash const DIMENSION: usize; diff --git a/src/symmetric/message_hash/poseidon.rs b/src/symmetric/message_hash/poseidon.rs index 2e7fd99..01e0503 100644 --- a/src/symmetric/message_hash/poseidon.rs +++ b/src/symmetric/message_hash/poseidon.rs @@ -8,6 +8,7 @@ use super::MessageHash; use crate::F; use crate::MESSAGE_LENGTH; use crate::TWEAK_SEPARATOR_FOR_MESSAGE_HASH; +use crate::array::FieldArray; use crate::poseidon2_24; use crate::symmetric::tweak_hash::poseidon::poseidon_compress; @@ -130,16 +131,16 @@ where [F; PARAMETER_LEN]: Serialize + DeserializeOwned, [F; RAND_LEN_FE]: Serialize + DeserializeOwned, { - type Parameter = [F; PARAMETER_LEN]; + type Parameter = FieldArray; - type Randomness = [F; RAND_LEN_FE]; + type Randomness = FieldArray; const DIMENSION: usize = DIMENSION; const BASE: usize = BASE; fn rand(rng: &mut R) -> Self::Randomness { - rng.random() + FieldArray(rng.random()) } fn apply( @@ -238,7 +239,7 @@ mod tests { fn test_apply() { let mut rng = rand::rng(); - let parameter = rng.random(); + let parameter = FieldArray(rng.random()); let message = rng.random(); @@ -253,7 +254,7 @@ mod tests { fn test_apply_w1() { let mut rng = rand::rng(); - let parameter = rng.random(); + let parameter = FieldArray(rng.random()); let message = rng.random(); diff --git a/src/symmetric/message_hash/top_level_poseidon.rs b/src/symmetric/message_hash/top_level_poseidon.rs index 6ea7998..e4c99c9 100644 --- a/src/symmetric/message_hash/top_level_poseidon.rs +++ b/src/symmetric/message_hash/top_level_poseidon.rs @@ -9,6 +9,7 @@ use super::poseidon::encode_epoch; use super::poseidon::encode_message; use crate::F; use crate::MESSAGE_LENGTH; +use crate::array::FieldArray; use crate::hypercube::hypercube_find_layer; use crate::hypercube::hypercube_part_size; use crate::hypercube::map_to_vertex; @@ -118,16 +119,16 @@ 
where [F; PARAMETER_LEN]: Serialize + DeserializeOwned, [F; RAND_LEN]: Serialize + DeserializeOwned, { - type Parameter = [F; PARAMETER_LEN]; + type Parameter = FieldArray; - type Randomness = [F; RAND_LEN]; + type Randomness = FieldArray; const DIMENSION: usize = DIMENSION; const BASE: usize = BASE; fn rand(rng: &mut R) -> Self::Randomness { - rng.random() + FieldArray(rng.random()) } fn apply( @@ -256,7 +257,7 @@ mod tests { let mut rng = rand::rng(); - let parameter = rng.random(); + let parameter = FieldArray(rng.random()); let message = rng.random(); @@ -294,7 +295,7 @@ mod tests { let mut rng = rand::rng(); - let parameter = rng.random(); + let parameter = FieldArray(rng.random()); let randomness = MH::rand(&mut rng); let hash = MH::apply(¶meter, epoch, &randomness, &message); diff --git a/src/symmetric/prf.rs b/src/symmetric/prf.rs index 4297ca3..50504a6 100644 --- a/src/symmetric/prf.rs +++ b/src/symmetric/prf.rs @@ -1,11 +1,12 @@ use rand::Rng; -use serde::{Serialize, de::DeserializeOwned}; + +use crate::serialization::Serializable; use crate::MESSAGE_LENGTH; /// Trait to model a pseudorandom function (PRF) pub trait Pseudorandom { - type Key: Send + Sync + Serialize + DeserializeOwned; + type Key: Send + Sync + Serializable; type Domain; type Randomness; diff --git a/src/symmetric/tweak_hash.rs b/src/symmetric/tweak_hash.rs index b12fbe1..efbb90d 100644 --- a/src/symmetric/tweak_hash.rs +++ b/src/symmetric/tweak_hash.rs @@ -1,7 +1,6 @@ use rand::Rng; -use serde::{Serialize, de::DeserializeOwned}; -use ssz::{Decode, Encode}; +use crate::serialization::Serializable; use crate::symmetric::prf::Pseudorandom; /// Trait to model a tweakable hash function. @@ -18,21 +17,13 @@ use crate::symmetric::prf::Pseudorandom; /// applications in Merkle trees. pub trait TweakableHash { /// Public parameter type for the hash function - type Parameter: Copy + Sized + Send + Sync + Serialize + DeserializeOwned + Encode + Decode; + type Parameter: Copy + Send + Sync + Serializable; /// Tweak type for domain separation type Tweak; /// Domain element type (defines output and input types to the hash) - type Domain: Copy - + PartialEq - + Sized - + Send - + Sync - + Serialize - + DeserializeOwned - + Encode - + Decode; + type Domain: Copy + PartialEq + Send + Sync + Serializable; /// Generates a random public parameter. fn rand_parameter(rng: &mut R) -> Self::Parameter; diff --git a/src/symmetric/tweak_hash_tree.rs b/src/symmetric/tweak_hash_tree.rs index a02f5fa..74ef9ad 100644 --- a/src/symmetric/tweak_hash_tree.rs +++ b/src/symmetric/tweak_hash_tree.rs @@ -1,18 +1,90 @@ +use crate::serialization::Serializable; use crate::symmetric::tweak_hash::TweakableHash; use rand::rngs::StdRng; use rand::{Rng, SeedableRng}; use rayon::prelude::*; use serde::{Deserialize, Serialize}; +use ssz::{Decode, DecodeError, Encode}; /// A single layer of a sparse Hash-Tree /// based on tweakable hash function -#[derive(Serialize, Deserialize)] +#[derive(Clone, Serialize, Deserialize)] #[serde(bound = "")] struct HashTreeLayer { - start_index: usize, + start_index: u64, nodes: Vec, } +impl Encode for HashTreeLayer { + fn is_ssz_fixed_len() -> bool { + false + } + + fn ssz_bytes_len(&self) -> usize { + // - Fixed part: start_index (8 bytes) + offset (4 bytes) + // - Variable part: nodes + 8 + 4 + self.nodes.ssz_bytes_len() + } + + fn ssz_append(&self, buf: &mut Vec) { + // SSZ Container encoding order: + // 1. Fixed field: start_index + self.start_index.ssz_append(buf); + + // 2. 
Offset for variable field: nodes + // Offset points to where variable data starts = end of fixed part + // 8 bytes (start_index) + 4 bytes (offset itself) + let offset: u32 = 12; + buf.extend_from_slice(&offset.to_le_bytes()); + + // 3. Variable data: nodes + self.nodes.ssz_append(buf); + } +} + +impl Decode for HashTreeLayer { + fn is_ssz_fixed_len() -> bool { + false + } + + fn from_ssz_bytes(bytes: &[u8]) -> Result { + // Minimum size: start_index (8) + offset (4) = 12 bytes + const FIXED_SIZE: usize = 12; + if bytes.len() < FIXED_SIZE { + return Err(DecodeError::InvalidByteLength { + len: bytes.len(), + expected: FIXED_SIZE, + }); + } + + // 1. Decode fixed field: start_index + let start_index = u64::from_ssz_bytes(&bytes[0..8])?; + + // 2. Read offset for variable field + let offset = u32::from_le_bytes(bytes[8..12].try_into().map_err(|_| { + DecodeError::InvalidByteLength { + len: bytes.len(), + expected: 12, + } + })?) as usize; + + // 3. Validate offset points to end of fixed part + if offset != FIXED_SIZE { + return Err(DecodeError::InvalidByteLength { + len: offset, + expected: FIXED_SIZE, + }); + } + + // 4. Decode variable field: nodes + let nodes = Vec::::from_ssz_bytes(&bytes[offset..])?; + + Ok(Self { start_index, nodes }) + } +} + +impl Serializable for HashTreeLayer {} + impl HashTreeLayer { /// Construct a layer from a contiguous run of nodes and pad it so that: /// - the layer starts at an even index (a left child), and @@ -64,7 +136,7 @@ impl HashTreeLayer { // Return the padded layer with the corrected start index. Self { - start_index: actual_start_index, + start_index: actual_start_index as u64, nodes: out, } } @@ -88,11 +160,11 @@ pub struct HashSubTree { /// Depth of the full tree. The tree can have at most /// 1 << depth many leafs. The full tree has depth + 1 /// many layers, whereas the sub-tree can have less. - depth: usize, + depth: u64, /// The lowest layer of the sub-tree. If this represents the /// full tree, then lowest_layer = 0. - lowest_layer: usize, + lowest_layer: u64, /// Layers of the hash tree, starting with the /// lowest_level. That is, layers[i] contains the nodes @@ -102,6 +174,82 @@ pub struct HashSubTree { layers: Vec>, } +impl Encode for HashSubTree { + fn is_ssz_fixed_len() -> bool { + false + } + + fn ssz_bytes_len(&self) -> usize { + // - Fixed part: depth (8) + lowest_layer (8) + offset (4) + // - Variable part: layers + 8 + 8 + 4 + self.layers.ssz_bytes_len() + } + + fn ssz_append(&self, buf: &mut Vec) { + // SSZ Container encoding order: + // 1. Fixed field: depth + self.depth.ssz_append(buf); + + // 2. Fixed field: lowest_layer + self.lowest_layer.ssz_append(buf); + + // 3. Offset for variable field: layers + let offset: u32 = 20; // 8 (depth) + 8 (lowest_layer) + 4 (offset itself) + buf.extend_from_slice(&offset.to_le_bytes()); + + // 4. Variable data: layers + self.layers.ssz_append(buf); + } +} + +impl Decode for HashSubTree { + fn is_ssz_fixed_len() -> bool { + false + } + + fn from_ssz_bytes(bytes: &[u8]) -> Result { + // Minimum size: depth (8) + lowest_layer (8) + offset (4) = 20 bytes + const FIXED_SIZE: usize = 20; + if bytes.len() < FIXED_SIZE { + return Err(DecodeError::InvalidByteLength { + len: bytes.len(), + expected: FIXED_SIZE, + }); + } + + // 1. Decode fixed field: depth + let depth = u64::from_ssz_bytes(&bytes[0..8])?; + + // 2. Decode fixed field: lowest_layer + let lowest_layer = u64::from_ssz_bytes(&bytes[8..16])?; + + // 3. 
Read offset for variable field + let offset = u32::from_le_bytes(bytes[16..20].try_into().map_err(|_| { + DecodeError::InvalidByteLength { + len: bytes.len(), + expected: 20, + } + })?) as usize; + + // 4. Validate offset points to end of fixed part + if offset != FIXED_SIZE { + return Err(DecodeError::InvalidByteLength { + len: offset, + expected: FIXED_SIZE, + }); + } + + // 5. Decode variable field: layers + let layers = Vec::>::from_ssz_bytes(&bytes[offset..])?; + + Ok(Self { + depth, + lowest_layer, + layers, + }) + } +} + /// Opening in a hash-tree: a co-path, without the leaf #[derive(Serialize, Deserialize)] #[serde(bound = "")] @@ -112,6 +260,71 @@ pub struct HashTreeOpening { co_path: Vec, } +impl Encode for HashTreeOpening { + fn is_ssz_fixed_len() -> bool { + false + } + + fn ssz_bytes_len(&self) -> usize { + // - Fixed part: offset (4 bytes) + // - Variable part: co_path + 4 + self.co_path.ssz_bytes_len() + } + + fn ssz_append(&self, buf: &mut Vec) { + // SSZ Container encoding order: + // 1. Offset for variable field: co_path + // Only the offset itself in fixed part + let offset: u32 = 4; + buf.extend_from_slice(&offset.to_le_bytes()); + + // 2. Variable data: co_path + self.co_path.ssz_append(buf); + } +} + +impl Decode for HashTreeOpening { + fn is_ssz_fixed_len() -> bool { + false + } + + fn from_ssz_bytes(bytes: &[u8]) -> Result { + // Minimum size: offset (4 bytes) + const FIXED_SIZE: usize = 4; + if bytes.len() < FIXED_SIZE { + return Err(DecodeError::InvalidByteLength { + len: bytes.len(), + expected: FIXED_SIZE, + }); + } + + // 1. Read offset for variable field + let offset = u32::from_le_bytes(bytes[0..4].try_into().map_err(|_| { + DecodeError::InvalidByteLength { + len: bytes.len(), + expected: 4, + } + })?) as usize; + + // 2. Validate offset points to end of fixed part + if offset != FIXED_SIZE { + return Err(DecodeError::InvalidByteLength { + len: offset, + expected: FIXED_SIZE, + }); + } + + // 3. 
Decode variable field: co_path + let co_path = Vec::::from_ssz_bytes(&bytes[offset..])?; + + Ok(Self { co_path }) + } +} + +impl Serializable for HashTreeOpening {} + +impl Serializable for HashSubTree {} + impl HashSubTree where TH: TweakableHash, @@ -173,7 +386,7 @@ where let prev = &layers[level - lowest_layer]; // Parent layer starts at half the previous start index - let parent_start = prev.start_index >> 1; + let parent_start = (prev.start_index >> 1) as usize; // Compute all parents in parallel, pairing children two-by-two // @@ -199,8 +412,8 @@ where } Self { - depth, - lowest_layer, + depth: depth as u64, + lowest_layer: lowest_layer as u64, layers, } } @@ -287,7 +500,7 @@ where let bottom_tree_root = bottom_tree.layers[depth / 2].nodes[bottom_tree_index % 2]; bottom_tree.layers.truncate(depth / 2); bottom_tree.layers.push(HashTreeLayer { - start_index: bottom_tree_index, + start_index: bottom_tree_index as u64, nodes: vec![bottom_tree_root], }); @@ -316,29 +529,28 @@ where "Hash-Tree path: Need at least one layer" ); assert!( - (position as u64) >= (self.layers[0].start_index as u64), + (position as u64) >= self.layers[0].start_index, "Hash-Tree path: Invalid position, position before start index" ); assert!( - (position as u64) - < (self.layers[0].start_index as u64 + self.layers[0].nodes.len() as u64), + (position as u64) < self.layers[0].start_index + self.layers[0].nodes.len() as u64, "Hash-Tree path: Invalid position, position too large" ); // in our co-path, we will have one node per layer // except the final layer (which is just the root) - let mut co_path = Vec::with_capacity(self.depth); + let mut co_path = Vec::with_capacity(self.depth as usize); let mut current_position = position; - for l in 0..(self.depth - self.lowest_layer) { + for l in 0..((self.depth - self.lowest_layer) as usize) { // if we are already at the root, we can stop (this is a special case for bottom trees) if self.layers[l].nodes.len() <= 1 { break; } // position of the sibling that we want to include let sibling_position = current_position ^ 0x01; - let sibling_position_in_vec = sibling_position - self.layers[l].start_index as u32; - // add to the co-path - let sibling = self.layers[l].nodes[sibling_position_in_vec as usize]; + let sibling_position_in_vec = + (sibling_position as u64 - self.layers[l].start_index) as usize; + let sibling = self.layers[l].nodes[sibling_position_in_vec]; co_path.push(sibling); // new position in next layer current_position >>= 1; @@ -737,4 +949,323 @@ mod tests { leaf_len, ); } + + #[test] + fn test_ssz_encoding_structure() { + let mut rng = rand::rng(); + + // HashTreeLayer: Generate sample nodes + let nodes: Vec<_> = (0..3).map(|_| TestTH::rand_domain(&mut rng)).collect(); + // Create layer with specific index + let layer = HashTreeLayer:: { + start_index: 256, + nodes, + }; + // Serialize to bytes + let encoded = layer.as_ssz_bytes(); + // Verify minimum size: 8 bytes for index + 4 bytes for offset + assert!(encoded.len() >= 12); + // Verify index value in bytes 0-8 + assert_eq!(u64::from_le_bytes(encoded[0..8].try_into().unwrap()), 256); + // Verify offset value in bytes 8-12 points to byte 12 + assert_eq!(u32::from_le_bytes(encoded[8..12].try_into().unwrap()), 12); + + // HashSubTree: Create minimal tree with no layers + let tree = HashSubTree:: { + depth: 16, + lowest_layer: 8, + layers: vec![], + }; + // Serialize to bytes + let encoded = tree.as_ssz_bytes(); + // Verify minimum size: 8 + 8 + 4 = 20 bytes + assert!(encoded.len() >= 20); + // Verify depth value 
in bytes 0-8 + assert_eq!(u64::from_le_bytes(encoded[0..8].try_into().unwrap()), 16); + // Verify lowest layer value in bytes 8-16 + assert_eq!(u64::from_le_bytes(encoded[8..16].try_into().unwrap()), 8); + // Verify offset value in bytes 16-20 points to byte 20 + assert_eq!(u32::from_le_bytes(encoded[16..20].try_into().unwrap()), 20); + + // HashTreeOpening: Generate authentication path + let co_path: Vec<_> = (0..5).map(|_| TestTH::rand_domain(&mut rng)).collect(); + // Create opening structure + let opening = HashTreeOpening:: { co_path }; + // Serialize to bytes + let encoded = opening.as_ssz_bytes(); + // Verify minimum size: 4 bytes for offset + assert!(encoded.len() >= 4); + // Verify offset value in bytes 0-4 points to byte 4 + assert_eq!(u32::from_le_bytes(encoded[0..4].try_into().unwrap()), 4); + } + + #[test] + fn test_ssz_decoding_errors() { + // HashTreeLayer: Buffer too small (8 bytes instead of minimum 12) + let encoded = vec![0u8; 8]; + // Attempt decode, expect error + let result = HashTreeLayer::::from_ssz_bytes(&encoded); + assert!(matches!(result, Err(DecodeError::InvalidByteLength { .. }))); + + // HashTreeLayer: Invalid offset value (99 instead of 12) + let mut encoded = vec![0u8; 12]; + // Write zero for index field + encoded[0..8].copy_from_slice(&0u64.to_le_bytes()); + // Write incorrect offset + encoded[8..12].copy_from_slice(&99u32.to_le_bytes()); + // Attempt decode, expect error with expected value 12 + let result = HashTreeLayer::::from_ssz_bytes(&encoded); + assert!(matches!( + result, + Err(DecodeError::InvalidByteLength { expected: 12, .. }) + )); + + // HashSubTree: Buffer too small (16 bytes instead of minimum 20) + let encoded = vec![0u8; 16]; + let result = HashSubTree::::from_ssz_bytes(&encoded); + assert!(matches!(result, Err(DecodeError::InvalidByteLength { .. }))); + + // HashSubTree: Invalid offset value (100 instead of 20) + let mut encoded = vec![0u8; 20]; + // Write depth field + encoded[0..8].copy_from_slice(&10u64.to_le_bytes()); + // Write lowest layer field + encoded[8..16].copy_from_slice(&5u64.to_le_bytes()); + // Write incorrect offset + encoded[16..20].copy_from_slice(&100u32.to_le_bytes()); + let result = HashSubTree::::from_ssz_bytes(&encoded); + assert!(matches!( + result, + Err(DecodeError::InvalidByteLength { expected: 20, .. }) + )); + + // HashTreeOpening: Buffer too small (2 bytes instead of minimum 4) + let encoded = vec![0u8; 2]; + let result = HashTreeOpening::::from_ssz_bytes(&encoded); + assert!(matches!(result, Err(DecodeError::InvalidByteLength { .. }))); + + // HashTreeOpening: Invalid offset value (10 instead of 4) + let mut encoded = vec![0u8; 4]; + // Write incorrect offset + encoded[0..4].copy_from_slice(&10u32.to_le_bytes()); + let result = HashTreeOpening::::from_ssz_bytes(&encoded); + assert!(matches!( + result, + Err(DecodeError::InvalidByteLength { expected: 4, .. 
}) + )); + } + + #[test] + fn test_ssz_determinism() { + let mut rng = rand::rng(); + + // HashTreeLayer: Generate random nodes + let nodes: Vec<_> = (0..7).map(|_| TestTH::rand_domain(&mut rng)).collect(); + // Create structure + let layer = HashTreeLayer:: { + start_index: 999, + nodes, + }; + // Encode twice, verify identical bytes + let encoded1 = layer.as_ssz_bytes(); + let encoded2 = layer.as_ssz_bytes(); + assert_eq!(encoded1, encoded2); + + // HashSubTree: Create tree with one layer + let layer = HashTreeLayer:: { + start_index: 4, + nodes: (0..6).map(|_| TestTH::rand_domain(&mut rng)).collect(), + }; + let tree = HashSubTree:: { + depth: 20, + lowest_layer: 10, + layers: vec![layer], + }; + // Encode twice, verify identical bytes + let encoded1 = tree.as_ssz_bytes(); + let encoded2 = tree.as_ssz_bytes(); + assert_eq!(encoded1, encoded2); + + // HashTreeOpening: Generate random authentication path + let co_path: Vec<_> = (0..15).map(|_| TestTH::rand_domain(&mut rng)).collect(); + let opening = HashTreeOpening:: { co_path }; + // Encode twice, verify identical bytes + let encoded1 = opening.as_ssz_bytes(); + let encoded2 = opening.as_ssz_bytes(); + assert_eq!(encoded1, encoded2); + } + + #[test] + fn test_ssz_merkle_integration() { + let mut rng = rand::rng(); + let parameter = TestTH::rand_parameter(&mut rng); + + // Build tree: 8 leaves at depth 3 + let num_leafs = 8; + let depth = 3; + let start_index = 0; + let leaf_len = 2; + // Generate leaf data + let mut leafs = Vec::new(); + for _ in 0..num_leafs { + let leaf: Vec<_> = (0..leaf_len) + .map(|_| TestTH::rand_domain(&mut rng)) + .collect(); + leafs.push(leaf); + } + // Hash leaves for tree construction + let leafs_hashes: Vec<_> = leafs + .iter() + .enumerate() + .map(|(i, v)| TestTH::apply(¶meter, &TestTH::tree_tweak(0, i as u32), v.as_slice())) + .collect(); + // Build complete merkle tree + let tree = HashSubTree::::new_subtree( + &mut rng, + 0, + depth, + start_index, + ¶meter, + leafs_hashes, + ); + let root = tree.root(); + + // Test tree serialization roundtrip + let tree_encoded = tree.as_ssz_bytes(); + let tree_decoded = HashSubTree::::from_ssz_bytes(&tree_encoded).unwrap(); + // Verify decoded tree has same root + assert_eq!(root, tree_decoded.root()); + + // Test authentication path at position 3 + let position = 3u32; + let path = tree.path(position); + let leaf = &leafs[position as usize]; + + // Test path serialization roundtrip + let path_encoded = path.as_ssz_bytes(); + let path_decoded = HashTreeOpening::::from_ssz_bytes(&path_encoded).unwrap(); + + // Verify decoded path authenticates correctly + assert!(hash_tree_verify( + ¶meter, + &root, + position, + leaf, + &path_decoded + )); + + // Verify path from decoded tree also works + let path_from_decoded = tree_decoded.path(position); + assert!(hash_tree_verify( + ¶meter, + &root, + position, + leaf, + &path_from_decoded + )); + } + + proptest! 
{ + #[test] + fn proptest_hash_tree_layer_ssz_roundtrip( + start_index in 0u64..1000, + num_nodes in 0usize..20, + ) { + // Generate random nodes + let mut rng = rand::rng(); + let nodes: Vec<_> = (0..num_nodes).map(|_| TestTH::rand_domain(&mut rng)).collect(); + // Create layer structure + let layer = HashTreeLayer:: { + start_index, + nodes, + }; + + // Perform serialization roundtrip + let encoded = layer.as_ssz_bytes(); + let decoded = HashTreeLayer::::from_ssz_bytes(&encoded).unwrap(); + + // Verify index field preserved + prop_assert_eq!(layer.start_index, decoded.start_index); + // Verify node count preserved + prop_assert_eq!(layer.nodes.len(), decoded.nodes.len()); + // Verify each node value preserved + for i in 0..layer.nodes.len() { + prop_assert_eq!(layer.nodes[i], decoded.nodes[i]); + } + // Verify determinism by re-encoding + let reencoded = decoded.as_ssz_bytes(); + prop_assert_eq!(encoded, reencoded); + } + + #[test] + fn proptest_hash_sub_tree_ssz_roundtrip( + depth in 1u64..32, + lowest_layer in 0u64..16, + num_layers in 0usize..5, + ) { + // Ensure valid tree configuration + prop_assume!(lowest_layer < depth); + + // Generate random layers + let mut rng = rand::rng(); + let mut layers = Vec::new(); + for _ in 0..num_layers { + let num_nodes = rng.random_range(0..10); + let layer = HashTreeLayer:: { + start_index: rng.random_range(0..100), + nodes: (0..num_nodes).map(|_| TestTH::rand_domain(&mut rng)).collect(), + }; + layers.push(layer); + } + // Create tree structure + let tree = HashSubTree:: { + depth, + lowest_layer, + layers, + }; + + // Perform serialization roundtrip + let encoded = tree.as_ssz_bytes(); + let decoded = HashSubTree::::from_ssz_bytes(&encoded).unwrap(); + + // Verify tree metadata preserved + prop_assert_eq!(tree.depth, decoded.depth); + prop_assert_eq!(tree.lowest_layer, decoded.lowest_layer); + // Verify layer count preserved + prop_assert_eq!(tree.layers.len(), decoded.layers.len()); + // Verify each layer structure preserved + for i in 0..tree.layers.len() { + prop_assert_eq!(tree.layers[i].start_index, decoded.layers[i].start_index); + prop_assert_eq!(tree.layers[i].nodes.len(), decoded.layers[i].nodes.len()); + } + // Verify determinism by re-encoding + let reencoded = decoded.as_ssz_bytes(); + prop_assert_eq!(encoded, reencoded); + } + + #[test] + fn proptest_hash_tree_opening_ssz_roundtrip( + co_path_len in 0usize..64, + ) { + // Generate random authentication path + let mut rng = rand::rng(); + let co_path: Vec<_> = (0..co_path_len).map(|_| TestTH::rand_domain(&mut rng)).collect(); + // Create opening structure + let opening = HashTreeOpening:: { co_path }; + + // Perform serialization roundtrip + let encoded = opening.as_ssz_bytes(); + let decoded = HashTreeOpening::::from_ssz_bytes(&encoded).unwrap(); + + // Verify path length preserved + prop_assert_eq!(opening.co_path.len(), decoded.co_path.len()); + // Verify each path element preserved + for i in 0..opening.co_path.len() { + prop_assert_eq!(opening.co_path[i], decoded.co_path[i]); + } + // Verify determinism by re-encoding + let reencoded = decoded.as_ssz_bytes(); + prop_assert_eq!(encoded, reencoded); + } + } }
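
Note (not part of the diff): the following is a minimal usage sketch of the unified `Serializable` supertrait introduced in `src/serialization.rs`. It relies only on items shown above (`serialization::Serializable` and `signature::SignatureScheme`, whose associated types are now bounded by `Serializable`); the helper function name and its placement are illustrative, not code from this changeset.

use crate::serialization::Serializable;
use crate::signature::SignatureScheme;
use ssz::DecodeError;

/// Round-trip a public key and a signature through the canonical byte format.
/// Works for any scheme because `SignatureScheme` now bounds its associated
/// types by `Serializable`.
fn roundtrip_scheme_objects<S: SignatureScheme>(
    pk: &S::PublicKey,
    sig: &S::Signature,
) -> Result<(S::PublicKey, S::Signature), DecodeError> {
    // `to_bytes` currently delegates to SSZ encoding: field elements are emitted
    // as canonical u32 values in little-endian order.
    let pk_bytes = pk.to_bytes();
    let sig_bytes = sig.to_bytes();

    // `from_bytes` returns a `DecodeError` on malformed input (e.g. bad offsets)
    // rather than panicking, which the offset validation added in this diff enforces.
    let pk = S::PublicKey::from_bytes(&pk_bytes)?;
    let sig = S::Signature::from_bytes(&sig_bytes)?;
    Ok((pk, sig))
}

Centralizing the serde and SSZ bounds in this one supertrait is the main refactor of the diff: trait definitions such as `IncomparableEncoding`, `MessageHash`, `Pseudorandom`, and `TweakableHash` no longer repeat `Serialize + DeserializeOwned + Encode + Decode` on their associated types.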