Skip to content
61 changes: 43 additions & 18 deletions src/signature/generalized_xmss.rs
Original file line number Diff line number Diff line change
Expand Up @@ -587,7 +587,10 @@ mod tests {
inc_encoding::target_sum::TargetSumEncoding,
signature::test_templates::test_signature_scheme_correctness,
symmetric::{
message_hash::{MessageHash, poseidon::PoseidonMessageHashW1},
message_hash::{
MessageHash,
poseidon::{PoseidonMessageHash, PoseidonMessageHashW1},
},
prf::shake_to_field::ShakePRFtoF,
tweak_hash::poseidon::PoseidonTweakW1L5,
},
Expand All @@ -602,6 +605,32 @@ mod tests {

type TestTH = PoseidonTweakHash<5, 7, 2, 9, 155>;

fn compress_prf_output<const OUT_LEN: usize>(value: [F; 24]) -> FieldArray<OUT_LEN> {
assert!(
OUT_LEN <= 24,
"Poseidon compression output length must be <= 24"
);
let hash = crate::symmetric::tweak_hash::poseidon::poseidon_compress::<F, _, 24, OUT_LEN>(
&crate::poseidon2_24(),
value.as_slice(),
);
FieldArray(hash)
}

// Compress a wide PRF output down to the four-field-element domain used by PoseidonTweakHash.
impl From<[F; 24]> for FieldArray<4> {
fn from(value: [F; 24]) -> Self {
compress_prf_output::<4>(value)
}
}

// Compress a wide PRF output down to the eight-field-element domain used by PoseidonTweakHash.
impl From<[F; 24]> for FieldArray<8> {
fn from(value: [F; 24]) -> Self {
compress_prf_output::<8>(value)
}
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We don't need this


#[test]
pub fn test_target_sum_poseidon() {
// Note: do not use these parameters, they are just for testing
Expand All @@ -613,7 +642,7 @@ mod tests {
const MAX_CHUNK_VALUE: usize = BASE - 1;
const EXPECTED_SUM: usize = NUM_CHUNKS * MAX_CHUNK_VALUE / 2;
type IE = TargetSumEncoding<MH, EXPECTED_SUM>;
const LOG_LIFETIME: usize = 6;
const LOG_LIFETIME: usize = 10;
type Sig = GeneralizedXMSSSignatureScheme<PRF, IE, TH, LOG_LIFETIME>;

Sig::internal_consistency_check();
Expand All @@ -635,7 +664,7 @@ mod tests {
const MAX_CHUNK_VALUE: usize = BASE - 1;
const EXPECTED_SUM: usize = NUM_CHUNKS * MAX_CHUNK_VALUE / 2;
type IE = TargetSumEncoding<MH, EXPECTED_SUM>;
const LOG_LIFETIME: usize = 6;
const LOG_LIFETIME: usize = 10;
type Sig = GeneralizedXMSSSignatureScheme<PRF, IE, TH, LOG_LIFETIME>;

Sig::internal_consistency_check();
Expand Down Expand Up @@ -666,15 +695,13 @@ mod tests {
assert_eq!(rho1, rho2);
}

/*#[test]
pub fn test_large_base_sha() {
#[test]
pub fn test_large_base_poseidon() {
// Note: do not use these parameters, they are just for testing
type PRF = ShaPRF<24, 8>;
type TH = ShaTweak192192;

// use chunk size 8
type MH = ShaMessageHash<24, 8, 32, 8>;
const TARGET_SUM: usize = 1 << 12;
type PRF = ShakePRFtoF<24, 8>;
type TH = PoseidonTweakHash<4, 4, 2, 8, 8>;
type MH = PoseidonMessageHash<4, 8, 2, 8, 256, 2, 9>;
const TARGET_SUM: usize = 8 * (256 - 1) / 2;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why we can't have

Suggested change
const TARGET_SUM: usize = 8 * (256 - 1) / 2;
const TARGET_SUM: usize = 1 << 12;

type IE = TargetSumEncoding<MH, TARGET_SUM>;
const LOG_LIFETIME: usize = 10;
type Sig = GeneralizedXMSSSignatureScheme<PRF, IE, TH, LOG_LIFETIME>;
Expand All @@ -686,13 +713,11 @@ mod tests {
}

#[test]
pub fn test_large_dimension_sha() {
pub fn test_large_dimension_poseidon() {
// Note: do not use these parameters, they are just for testing
type PRF = ShaPRF<24, 8>;
type TH = ShaTweak192192;

// use 256 chunks
type MH = ShaMessageHash<24, 8, 256, 1>;
type PRF = ShakePRFtoF<24, 8>;
type TH = PoseidonTweakHash<4, 8, 2, 8, 256>;
type MH = PoseidonMessageHash<4, 8, 8, 256, 2, 2, 9>;
const TARGET_SUM: usize = 128;
type IE = TargetSumEncoding<MH, TARGET_SUM>;
const LOG_LIFETIME: usize = 10;
Expand All @@ -702,7 +727,7 @@ mod tests {

test_signature_scheme_correctness::<Sig>(2, 0, Sig::LIFETIME as usize);
test_signature_scheme_correctness::<Sig>(19, 0, Sig::LIFETIME as usize);
}*/
}

#[test]
pub fn test_expand_activation_time() {
Expand Down