Skip to content

Commit a35156c

Browse files
authored
public key: add ssz impl (#12)
* public key: add ssz impl * fix clippy * better abstraction * trait bound for public key * clippy * rm useless comment
1 parent 857b670 commit a35156c

File tree

8 files changed

+628
-48
lines changed

8 files changed

+628
-48
lines changed

Cargo.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,9 @@ dashmap = "6.1.0"
3939
serde = { version = "1.0", features = ["derive", "alloc"] }
4040
thiserror = "2.0"
4141

42+
ssz = { package = "ethereum_ssz", version = "0.10.0" }
43+
ssz_derive = { package = "ethereum_ssz_derive", version = "0.10.0" }
44+
4245
p3-field = { git = "https://github.com/Plonky3/Plonky3.git", rev = "a33a312" }
4346
p3-baby-bear = { git = "https://github.com/Plonky3/Plonky3.git", rev = "a33a312" }
4447
p3-koala-bear = { git = "https://github.com/Plonky3/Plonky3.git", rev = "a33a312" }

src/array.rs

Lines changed: 358 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,358 @@
1+
use serde::{Deserialize, Deserializer, Serialize, de::Visitor};
2+
use ssz::{Decode, DecodeError, Encode};
3+
use std::ops::{Deref, DerefMut};
4+
5+
use crate::F;
6+
use p3_field::{PrimeCharacteristicRing, PrimeField32, RawDataSerializable};
7+
8+
/// A wrapper around an array of field elements that implements SSZ Encode/Decode.
9+
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
10+
#[repr(transparent)]
11+
pub struct FieldArray<const N: usize>(pub [F; N]);
12+
13+
impl<const N: usize> Deref for FieldArray<N> {
14+
type Target = [F; N];
15+
16+
fn deref(&self) -> &Self::Target {
17+
&self.0
18+
}
19+
}
20+
21+
impl<const N: usize> DerefMut for FieldArray<N> {
22+
fn deref_mut(&mut self) -> &mut Self::Target {
23+
&mut self.0
24+
}
25+
}
26+
27+
impl<const N: usize> From<[F; N]> for FieldArray<N> {
28+
fn from(arr: [F; N]) -> Self {
29+
Self(arr)
30+
}
31+
}
32+
33+
impl<const N: usize> From<FieldArray<N>> for [F; N] {
34+
fn from(field_array: FieldArray<N>) -> Self {
35+
field_array.0
36+
}
37+
}
38+
39+
impl<const N: usize> Encode for FieldArray<N> {
40+
fn is_ssz_fixed_len() -> bool {
41+
true
42+
}
43+
44+
fn ssz_fixed_len() -> usize {
45+
N * F::NUM_BYTES
46+
}
47+
48+
fn ssz_bytes_len(&self) -> usize {
49+
N * F::NUM_BYTES
50+
}
51+
52+
fn ssz_append(&self, buf: &mut Vec<u8>) {
53+
buf.reserve(N * F::NUM_BYTES);
54+
for elem in &self.0 {
55+
let value = elem.as_canonical_u32();
56+
buf.extend_from_slice(&value.to_le_bytes());
57+
}
58+
}
59+
}
60+
61+
impl<const N: usize> Decode for FieldArray<N> {
62+
fn is_ssz_fixed_len() -> bool {
63+
true
64+
}
65+
66+
fn ssz_fixed_len() -> usize {
67+
N * F::NUM_BYTES
68+
}
69+
70+
fn from_ssz_bytes(bytes: &[u8]) -> Result<Self, DecodeError> {
71+
let expected_len = N * F::NUM_BYTES;
72+
if bytes.len() != expected_len {
73+
return Err(DecodeError::InvalidByteLength {
74+
len: bytes.len(),
75+
expected: expected_len,
76+
});
77+
}
78+
79+
let arr = std::array::from_fn(|i| {
80+
let start = i * F::NUM_BYTES;
81+
let chunk = bytes[start..start + F::NUM_BYTES].try_into().unwrap();
82+
F::new(u32::from_le_bytes(chunk))
83+
});
84+
85+
Ok(Self(arr))
86+
}
87+
}
88+
89+
impl<const N: usize> Serialize for FieldArray<N> {
90+
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
91+
where
92+
S: serde::Serializer,
93+
{
94+
serializer.collect_seq(self.0.iter().map(PrimeField32::as_canonical_u32))
95+
}
96+
}
97+
98+
impl<'de, const N: usize> Deserialize<'de> for FieldArray<N> {
99+
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
100+
where
101+
D: Deserializer<'de>,
102+
{
103+
struct FieldArrayVisitor<const N: usize>;
104+
105+
impl<'de, const N: usize> Visitor<'de> for FieldArrayVisitor<N> {
106+
type Value = FieldArray<N>;
107+
108+
fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
109+
write!(formatter, "an array of {} field elements", N)
110+
}
111+
112+
fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
113+
where
114+
A: serde::de::SeqAccess<'de>,
115+
{
116+
let mut arr = [F::ZERO; N];
117+
for (i, p) in arr.iter_mut().enumerate() {
118+
let val: u32 = seq
119+
.next_element()?
120+
.ok_or_else(|| serde::de::Error::invalid_length(i, &self))?;
121+
*p = F::new(val);
122+
}
123+
Ok(FieldArray(arr))
124+
}
125+
}
126+
127+
deserializer.deserialize_seq(FieldArrayVisitor::<N>)
128+
}
129+
}
130+
131+
#[cfg(test)]
132+
mod tests {
133+
use super::*;
134+
use proptest::prelude::*;
135+
136+
/// Small parameter arrays
137+
const SMALL_SIZE: usize = 5;
138+
/// Hash output size
139+
const MEDIUM_SIZE: usize = 7;
140+
/// Larger parameter arrays
141+
const LARGE_SIZE: usize = 44;
142+
143+
#[test]
144+
fn test_ssz_roundtrip_zero_values() {
145+
// Start with an array of zeros
146+
let original = FieldArray([F::ZERO; SMALL_SIZE]);
147+
148+
// Encode to bytes using SSZ
149+
let encoded = original.as_ssz_bytes();
150+
151+
// Decode back from bytes
152+
let decoded = FieldArray::<SMALL_SIZE>::from_ssz_bytes(&encoded)
153+
.expect("Failed to decode valid SSZ bytes");
154+
155+
// Verify round-trip preserves the value
156+
assert_eq!(original, decoded, "Round-trip failed for zero values");
157+
}
158+
159+
#[test]
160+
fn test_ssz_roundtrip_max_values() {
161+
// Create array with maximum valid field values
162+
let max_val = F::ORDER_U32 - 1;
163+
let original = FieldArray([F::new(max_val); MEDIUM_SIZE]);
164+
165+
// Perform round-trip encoding/decoding
166+
let encoded = original.as_ssz_bytes();
167+
let decoded = FieldArray::<MEDIUM_SIZE>::from_ssz_bytes(&encoded)
168+
.expect("Failed to decode max values");
169+
170+
// Verify the values survived the round-trip
171+
assert_eq!(original, decoded, "Round-trip failed for max values");
172+
}
173+
174+
#[test]
175+
fn test_ssz_roundtrip_specific_values() {
176+
// Create an array with sequential values for easy verification
177+
let original = FieldArray([F::new(1), F::new(2), F::new(3), F::new(4), F::new(5)]);
178+
179+
// Encode and verify the byte representation
180+
let encoded = original.as_ssz_bytes();
181+
182+
// Each u32 should be encoded as F::NUM_BYTES bytes in little-endian
183+
assert_eq!(
184+
&encoded[0..F::NUM_BYTES],
185+
&[1, 0, 0, 0],
186+
"First element encoding incorrect"
187+
);
188+
assert_eq!(
189+
&encoded[F::NUM_BYTES..2 * F::NUM_BYTES],
190+
&[2, 0, 0, 0],
191+
"Second element encoding incorrect"
192+
);
193+
assert_eq!(
194+
&encoded[2 * F::NUM_BYTES..3 * F::NUM_BYTES],
195+
&[3, 0, 0, 0],
196+
"Third element encoding incorrect"
197+
);
198+
199+
// Decode and verify round-trip
200+
let decoded = FieldArray::<SMALL_SIZE>::from_ssz_bytes(&encoded)
201+
.expect("Failed to decode specific values");
202+
203+
assert_eq!(original, decoded, "Round-trip failed for specific values");
204+
}
205+
206+
#[test]
207+
fn test_ssz_encoding_deterministic() {
208+
let mut rng = rand::rng();
209+
210+
// Create a random field array
211+
let field_array = FieldArray(rng.random::<[F; SMALL_SIZE]>());
212+
213+
// Encode it multiple times
214+
let encoding1 = field_array.as_ssz_bytes();
215+
let encoding2 = field_array.as_ssz_bytes();
216+
let encoding3 = field_array.as_ssz_bytes();
217+
218+
// All encodings should be identical
219+
assert_eq!(encoding1, encoding2, "Encoding not deterministic (1 vs 2)");
220+
assert_eq!(encoding2, encoding3, "Encoding not deterministic (2 vs 3)");
221+
}
222+
223+
#[test]
224+
fn test_ssz_encoded_size() {
225+
let field_array = FieldArray([F::ZERO; LARGE_SIZE]);
226+
let encoded = field_array.as_ssz_bytes();
227+
228+
// Verify the encoded size matches expectations
229+
let expected_size = LARGE_SIZE * F::NUM_BYTES;
230+
assert_eq!(
231+
encoded.len(),
232+
expected_size,
233+
"Encoded size should be {} bytes (array of {} elements, {} bytes each)",
234+
expected_size,
235+
LARGE_SIZE,
236+
F::NUM_BYTES
237+
);
238+
239+
// Also verify the trait method reports the same size
240+
assert_eq!(
241+
field_array.ssz_bytes_len(),
242+
expected_size,
243+
"ssz_bytes_len() should match actual encoded size"
244+
);
245+
}
246+
247+
#[test]
248+
fn test_ssz_decode_rejects_wrong_length() {
249+
let expected_len = SMALL_SIZE * F::NUM_BYTES;
250+
251+
// Test buffer that's too short (missing one byte)
252+
let too_short = vec![0u8; expected_len - 1];
253+
let result = FieldArray::<SMALL_SIZE>::from_ssz_bytes(&too_short);
254+
assert!(result.is_err(), "Should reject buffer that's too short");
255+
if let Err(DecodeError::InvalidByteLength { len, expected }) = result {
256+
assert_eq!(len, expected_len - 1);
257+
assert_eq!(expected, expected_len);
258+
} else {
259+
panic!("Expected InvalidByteLength error");
260+
}
261+
262+
// Test buffer that's too long (extra byte)
263+
let too_long = vec![0u8; expected_len + 1];
264+
let result = FieldArray::<SMALL_SIZE>::from_ssz_bytes(&too_long);
265+
assert!(result.is_err(), "Should reject buffer that's too long");
266+
if let Err(DecodeError::InvalidByteLength { len, expected }) = result {
267+
assert_eq!(len, expected_len + 1);
268+
assert_eq!(expected, expected_len);
269+
} else {
270+
panic!("Expected InvalidByteLength error");
271+
}
272+
}
273+
274+
#[test]
275+
fn test_ssz_fixed_len_trait_methods() {
276+
// Arrays are always fixed-length in SSZ
277+
assert!(
278+
<FieldArray<SMALL_SIZE> as Encode>::is_ssz_fixed_len(),
279+
"FieldArray should report as fixed-length (Encode)"
280+
);
281+
assert!(
282+
<FieldArray<SMALL_SIZE> as Decode>::is_ssz_fixed_len(),
283+
"FieldArray should report as fixed-length (Decode)"
284+
);
285+
286+
// The fixed length should be N * F::NUM_BYTES
287+
let expected_len = SMALL_SIZE * F::NUM_BYTES;
288+
assert_eq!(
289+
<FieldArray<SMALL_SIZE> as Encode>::ssz_fixed_len(),
290+
expected_len,
291+
"Encode::ssz_fixed_len() incorrect"
292+
);
293+
assert_eq!(
294+
<FieldArray<SMALL_SIZE> as Decode>::ssz_fixed_len(),
295+
expected_len,
296+
"Decode::ssz_fixed_len() incorrect"
297+
);
298+
}
299+
300+
proptest! {
301+
#[test]
302+
fn proptest_ssz_roundtrip_large(
303+
values in prop::collection::vec(0u32..F::ORDER_U32, LARGE_SIZE)
304+
) {
305+
// Convert Vec to array for large sizes
306+
let arr: [F; LARGE_SIZE] = std::array::from_fn(|i| F::new(values[i]));
307+
let original = FieldArray(arr);
308+
309+
let encoded = original.as_ssz_bytes();
310+
let decoded = FieldArray::<LARGE_SIZE>::from_ssz_bytes(&encoded)
311+
.expect("Valid SSZ bytes should always decode");
312+
313+
prop_assert_eq!(original, decoded);
314+
}
315+
316+
#[test]
317+
fn proptest_ssz_deterministic(
318+
values in prop::array::uniform5(0u32..F::ORDER_U32)
319+
) {
320+
let arr = values.map(F::new);
321+
let field_array = FieldArray(arr);
322+
323+
// Encode twice and verify both encodings are identical
324+
let encoding1 = field_array.as_ssz_bytes();
325+
let encoding2 = field_array.as_ssz_bytes();
326+
327+
prop_assert_eq!(encoding1, encoding2);
328+
}
329+
330+
#[test]
331+
fn proptest_ssz_size_invariant(
332+
values in prop::array::uniform5(0u32..F::ORDER_U32)
333+
) {
334+
let arr = values.map(F::new);
335+
let field_array = FieldArray(arr);
336+
337+
let encoded = field_array.as_ssz_bytes();
338+
let expected_size = SMALL_SIZE * F::NUM_BYTES;
339+
340+
prop_assert_eq!(encoded.len(), expected_size);
341+
prop_assert_eq!(field_array.ssz_bytes_len(), expected_size);
342+
}
343+
}
344+
345+
#[test]
346+
fn test_equality() {
347+
let arr1 = FieldArray([F::new(1), F::new(2), F::new(3)]);
348+
let arr2 = FieldArray([F::new(1), F::new(2), F::new(3)]);
349+
let arr3 = FieldArray([F::new(1), F::new(2), F::new(4)]);
350+
351+
// Equal arrays should be equal
352+
assert_eq!(arr1, arr2);
353+
354+
// Different arrays should not be equal
355+
assert_ne!(arr1, arr3);
356+
assert_ne!(arr2, arr3);
357+
}
358+
}

src/lib.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ pub const TWEAK_SEPARATOR_FOR_CHAIN_HASH: u8 = 0x00;
1414
type F = KoalaBear;
1515
pub(crate) type PackedF = <F as Field>::Packing;
1616

17+
pub(crate) mod array;
1718
pub(crate) mod hypercube;
1819
pub(crate) mod inc_encoding;
1920
pub mod signature;

src/signature.rs

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ use std::ops::Range;
33
use crate::MESSAGE_LENGTH;
44
use rand::Rng;
55
use serde::{Serialize, de::DeserializeOwned};
6+
use ssz::{Decode, Encode};
67
use thiserror::Error;
78

89
/// Error enum for the signing process.
@@ -96,7 +97,9 @@ pub trait SignatureScheme {
9697
/// The public key used for verification.
9798
///
9899
/// The key must be serializable to allow for network transmission and storage.
99-
type PublicKey: Serialize + DeserializeOwned;
100+
///
101+
/// We must support SSZ encoding for Ethereum consensus layer compatibility.
102+
type PublicKey: Serialize + DeserializeOwned + Encode + Decode;
100103

101104
/// The secret key used for signing.
102105
///

0 commit comments

Comments
 (0)