Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,9 @@ dashmap = "6.1.0"
serde = { version = "1.0", features = ["derive", "alloc"] }
thiserror = "2.0"

ssz = { package = "ethereum_ssz", version = "0.10.0" }
ssz_derive = { package = "ethereum_ssz_derive", version = "0.10.0" }

p3-field = { git = "https://github.com/Plonky3/Plonky3.git", rev = "a33a312" }
p3-baby-bear = { git = "https://github.com/Plonky3/Plonky3.git", rev = "a33a312" }
p3-koala-bear = { git = "https://github.com/Plonky3/Plonky3.git", rev = "a33a312" }
Expand Down
232 changes: 231 additions & 1 deletion src/signature/generalized_xmss.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use rayon::prelude::*;
use serde::{Deserialize, Serialize};

use crate::{
MESSAGE_LENGTH,
F, MESSAGE_LENGTH,
inc_encoding::IncomparableEncoding,
signature::SignatureSchemeSecretKey,
symmetric::{
Expand All @@ -17,6 +17,9 @@ use crate::{

use super::{SignatureScheme, SigningError};

use p3_field::{PrimeField32, RawDataSerializable};
use ssz::{Decode, DecodeError, Encode};

/// Implementation of the generalized XMSS signature scheme
/// from any incomparable encoding scheme and any tweakable hash
///
Expand Down Expand Up @@ -525,6 +528,85 @@ where
}
}

impl<TH, const HASH_LEN: usize, const PARAM_LEN: usize> Encode for GeneralizedXMSSPublicKey<TH>
where
TH: TweakableHash<Domain = [F; HASH_LEN], Parameter = [F; PARAM_LEN]>,
{
fn is_ssz_fixed_len() -> bool {
true
}

fn ssz_fixed_len() -> usize {
(HASH_LEN + PARAM_LEN) * F::NUM_BYTES
}

fn ssz_bytes_len(&self) -> usize {
(HASH_LEN + PARAM_LEN) * F::NUM_BYTES
}

fn ssz_append(&self, buf: &mut Vec<u8>) {
// Reserve space for the output
buf.reserve((HASH_LEN + PARAM_LEN) * F::NUM_BYTES);

// Encode root
for elem in &self.root {
let value = elem.as_canonical_u32();
buf.extend_from_slice(&value.to_le_bytes());
}
// Encode parameter
for elem in &self.parameter {
let value = elem.as_canonical_u32();
buf.extend_from_slice(&value.to_le_bytes());
}
}
}

impl<TH, const HASH_LEN: usize, const PARAM_LEN: usize> Decode for GeneralizedXMSSPublicKey<TH>
where
TH: TweakableHash<Domain = [F; HASH_LEN], Parameter = [F; PARAM_LEN]>,
{
fn is_ssz_fixed_len() -> bool {
true
}

fn ssz_fixed_len() -> usize {
(HASH_LEN + PARAM_LEN) * F::NUM_BYTES
}

fn from_ssz_bytes(bytes: &[u8]) -> Result<Self, DecodeError> {
let expected_len = (HASH_LEN + PARAM_LEN) * F::NUM_BYTES;

if bytes.len() != expected_len {
return Err(DecodeError::InvalidByteLength {
len: bytes.len(),
expected: expected_len,
});
}

// We know this is safe because of the length check above.
let (root_bytes, param_bytes) = bytes.split_at(HASH_LEN * F::NUM_BYTES);

// Define a helper closure to decode an array from bytes.
let decode_array = |chunk: &[u8]| -> F {
let val = u32::from_le_bytes(chunk.try_into().unwrap());
F::new(val)
};

// Construct the root and parameter arrays.
let root = std::array::from_fn(|i| {
let start = i * F::NUM_BYTES;
decode_array(&root_bytes[start..start + F::NUM_BYTES])
});

let parameter = std::array::from_fn(|i| {
let start = i * F::NUM_BYTES;
decode_array(&param_bytes[start..start + F::NUM_BYTES])
});

Ok(Self { root, parameter })
}
}

/// Instantiations of the generalized XMSS signature scheme based on Poseidon2
pub mod instantiations_poseidon;
/// Instantiations of the generalized XMSS signature scheme based on the
Expand All @@ -545,6 +627,13 @@ mod tests {

use super::*;

use crate::symmetric::tweak_hash::poseidon::PoseidonTweakHash;
use p3_field::PrimeCharacteristicRing;
use rand::rng;
use ssz::{Decode, Encode};

type TestTH = PoseidonTweakHash<5, 7, 2, 9, 155>;

#[test]
pub fn test_target_sum_poseidon() {
// Note: do not use these parameters, they are just for testing
Expand Down Expand Up @@ -679,4 +768,145 @@ mod tests {
let (start, end_excl) = expand_activation_time::<LOG_LIFETIME>(12, 2);
assert!((start == 2) && (end_excl == 4));
}

#[test]
fn test_public_key_ssz_roundtrip() {
let mut rng = rng();
let root = TestTH::rand_domain(&mut rng);
let parameter = TestTH::rand_parameter(&mut rng);

let public_key = GeneralizedXMSSPublicKey::<TestTH> { root, parameter };

// Encode to SSZ
let encoded = public_key.as_ssz_bytes();

// Check expected size: (7 + 5) * 4 = 48 bytes
assert_eq!(encoded.len(), 48);

// Decode from SSZ
let decoded =
GeneralizedXMSSPublicKey::<TestTH>::from_ssz_bytes(&encoded).expect("Decoding failed");

// Check fields match
assert_eq!(public_key.root, decoded.root);
assert_eq!(public_key.parameter, decoded.parameter);
}

#[test]
fn test_public_key_ssz_deterministic() {
let mut rng = rng();
let root = TestTH::rand_domain(&mut rng);
let parameter = TestTH::rand_parameter(&mut rng);

let public_key = GeneralizedXMSSPublicKey::<TestTH> { root, parameter };

// Encode multiple times
let encoded1 = public_key.as_ssz_bytes();
let encoded2 = public_key.as_ssz_bytes();

// Should be identical
assert_eq!(encoded1, encoded2);
}

#[test]
fn test_public_key_ssz_zero_values() {
let root = [F::ZERO; 7];
let parameter = [F::ZERO; 5];

let public_key = GeneralizedXMSSPublicKey::<TestTH> { root, parameter };

let encoded = public_key.as_ssz_bytes();
let decoded =
GeneralizedXMSSPublicKey::<TestTH>::from_ssz_bytes(&encoded).expect("Decoding failed");

assert_eq!(public_key.root, decoded.root);
assert_eq!(public_key.parameter, decoded.parameter);
}

#[test]
fn test_public_key_ssz_max_values() {
use p3_field::PrimeField32;

let max_val = F::ORDER_U32 - 1;
let root = [F::new(max_val); 7];
let parameter = [F::new(max_val); 5];

let public_key = GeneralizedXMSSPublicKey::<TestTH> { root, parameter };

let encoded = public_key.as_ssz_bytes();
let decoded =
GeneralizedXMSSPublicKey::<TestTH>::from_ssz_bytes(&encoded).expect("Decoding failed");

assert_eq!(public_key.root, decoded.root);
assert_eq!(public_key.parameter, decoded.parameter);
}

#[test]
fn test_public_key_ssz_invalid_length_too_short() {
let bytes = vec![0u8; 47]; // Should be 48 bytes
let result = GeneralizedXMSSPublicKey::<TestTH>::from_ssz_bytes(&bytes);
assert!(result.is_err());
assert!(matches!(
result,
Err(DecodeError::InvalidByteLength {
len: 47,
expected: 48
})
));
}

#[test]
fn test_public_key_ssz_invalid_length_too_long() {
let bytes = vec![0u8; 49]; // Should be 48 bytes
let result = GeneralizedXMSSPublicKey::<TestTH>::from_ssz_bytes(&bytes);
assert!(result.is_err());
assert!(matches!(
result,
Err(DecodeError::InvalidByteLength {
len: 49,
expected: 48
})
));
}

#[test]
fn test_public_key_ssz_fixed_len_trait() {
assert!(<GeneralizedXMSSPublicKey::<TestTH> as Encode>::is_ssz_fixed_len());
assert_eq!(
<GeneralizedXMSSPublicKey::<TestTH> as Encode>::ssz_fixed_len(),
48
);
}

#[test]
fn test_public_key_ssz_specific_values() {
// Test with specific known values to verify byte ordering
let root = [
F::new(1),
F::new(2),
F::new(3),
F::new(4),
F::new(5),
F::new(6),
F::new(7),
];
let parameter = [F::new(10), F::new(20), F::new(30), F::new(40), F::new(50)];

let public_key = GeneralizedXMSSPublicKey::<TestTH> { root, parameter };

let encoded = public_key.as_ssz_bytes();

// Check first few bytes (little-endian encoding of 1)
assert_eq!(&encoded[0..4], &[1, 0, 0, 0]);
// Check encoding of 2
assert_eq!(&encoded[4..8], &[2, 0, 0, 0]);
// Check encoding of 10 (first parameter value)
assert_eq!(&encoded[28..32], &[10, 0, 0, 0]);

let decoded =
GeneralizedXMSSPublicKey::<TestTH>::from_ssz_bytes(&encoded).expect("Decoding failed");

assert_eq!(public_key.root, decoded.root);
assert_eq!(public_key.parameter, decoded.parameter);
}
}