diff --git a/coding/conformance.toml b/coding/conformance.toml
index e5bcf2fe41..524762cd37 100644
--- a/coding/conformance.toml
+++ b/coding/conformance.toml
@@ -8,7 +8,7 @@ hash = "edc22446bb2952609d0c8daccf2d22f8ad2b71eedfcf45296c3f4db49d78404a"
 
 ["commonware_coding::reed_solomon::tests::conformance::CodecConformance>"]
 n_cases = 65536
-hash = "aa8512bb8e86e967833edd1a6cc806280d5e7334e9dc8428a098de9204db12d1"
+hash = "45545e4d4aeb18b8bdb019e630fb9a1fa6dda9ed32b2d529a9213ec07ccab07c"
 
 ["commonware_coding::test::conformance::CodecConformance"]
 n_cases = 65536
@@ -16,8 +16,8 @@ hash = "1a412c5c279f981857081765537b85474184048d1b17053394f94fc42ac1dbf4"
 
 ["commonware_coding::zoda::tests::conformance::CodecConformance>"]
 n_cases = 65536
-hash = "ebbbe08eb9beb1c5215a5d67ad9deddaef7c54920e53a751b56a8261e60e0e52"
+hash = "0571442797c611b3822c8a9c54138de9f54fc5b9daaf01796f611a5c74466710"
 
 ["commonware_coding::zoda::tests::conformance::CodecConformance>"]
 n_cases = 65536
-hash = "929ce4f95f9d5784f995c52b7e5cde8b62663ab068848925314dc9f80eb27d34"
+hash = "fbf783e8550fe15cd7000f8185e1ca3bc9641ba0baf156ba6365d3b224e2222d"
diff --git a/coding/src/reed_solomon.rs b/coding/src/reed_solomon.rs
index e1a0673622..9a073abd1a 100644
--- a/coding/src/reed_solomon.rs
+++ b/coding/src/reed_solomon.rs
@@ -47,13 +47,13 @@ pub struct Chunk<H: Hasher> {
     /// The index of [Chunk] in the original data.
     index: u16,
 
-    /// The proof of the shard in the [bmt] at the given index.
-    proof: bmt::Proof<H>,
+    /// The multi-proof of the shard in the [bmt] at the given index.
+    proof: bmt::Proof<H::Digest>,
 }
 
 impl<H: Hasher> Chunk<H> {
     /// Create a new [Chunk] from the given shard, index, and proof.
-    const fn new(shard: Vec<u8>, index: u16, proof: bmt::Proof<H>) -> Self {
+    const fn new(shard: Vec<u8>, index: u16, proof: bmt::Proof<H::Digest>) -> Self {
         Self {
             shard,
             index,
@@ -75,7 +75,7 @@
 
         // Verify proof
         self.proof
-            .verify(&mut hasher, &shard_digest, self.index as u32, root)
+            .verify_element_inclusion(&mut hasher, &shard_digest, self.index as u32, root)
             .is_ok()
     }
 }
@@ -95,7 +95,7 @@
     fn read_cfg(reader: &mut impl Buf, cfg: &Self::Cfg) -> Result<Self, Error> {
         let shard = Vec::<u8>::read_range(reader, ..=cfg.maximum_shard_size)?;
         let index = u16::read(reader)?;
-        let proof = bmt::Proof::<H>::read(reader)?;
+        let proof = bmt::Proof::<H::Digest>::read_cfg(reader, &1)?;
         Ok(Self {
             shard,
             index,
@@ -381,7 +381,7 @@ fn decode(
 /// The encoder takes input data, splits it into `k` data shards, and generates `m` recovery
 /// shards using [Reed-Solomon encoding](https://en.wikipedia.org/wiki/Reed%E2%80%93Solomon_error_correction).
 /// All `n = k + m` shards are then used to build a [bmt], producing a single root hash. Each shard
-/// is packaged as a chunk containing the shard data, its index, and a Merkle proof against the [bmt] root.
+/// is packaged as a chunk containing the shard data, its index, and a Merkle multi-proof against the [bmt] root.
 ///
 /// ## Encoding
 ///
@@ -445,12 +445,12 @@ fn decode(
 /// Each chunk contains:
 /// - `shard`: The shard data (original or recovery).
 /// - `index`: The shard's original index (0 to n-1).
-/// - `proof`: A Merkle proof of the shard's inclusion in the [bmt].
+/// - `proof`: A Merkle multi-proof of the shard's inclusion in the [bmt].
 ///
 /// ## Decoding and Verification
 ///
 /// The decoder requires any `k` chunks to reconstruct the original data.
-/// 1. Each chunk's Merkle proof is verified against the [bmt] root.
+/// 1. Each chunk's Merkle multi-proof is verified against the [bmt] root.
 /// 2. The shards from the valid chunks are used to reconstruct the original `k` data shards.
 /// 3. To ensure consistency, the recovered data shards are re-encoded, and a new [bmt] root is
 ///    generated. This new root MUST match the original [bmt] root. This prevents attacks where
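// Sketch (not part of this diff): the single-chunk verification path after the
// rename, assuming `Sha256` and the `commonware_storage::bmt` API shown below;
// `shards` and `index` are illustrative placeholders.
use commonware_cryptography::{Hasher as _, Sha256};
use commonware_storage::bmt::Builder;

fn chunk_verification_sketch() {
    // Commit to two shards (stand-ins for data + recovery shards).
    let shards: Vec<Vec<u8>> = vec![b"shard-0".to_vec(), b"shard-1".to_vec()];
    let digests: Vec<_> = shards.iter().map(|s| Sha256::hash(s)).collect();
    let mut builder = Builder::<Sha256>::new(digests.len());
    for digest in &digests {
        builder.add(digest);
    }
    let tree = builder.build();
    let root = tree.root();

    // Each chunk now carries a single-element multi-proof; verification uses
    // `verify_element_inclusion` instead of the old `verify`.
    let index: u32 = 1;
    let proof = tree.proof(index).unwrap();
    let mut hasher = Sha256::default();
    assert!(proof
        .verify_element_inclusion(&mut hasher, &digests[index as usize], index, &root)
        .is_ok());
}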
diff --git a/coding/src/zoda.rs b/coding/src/zoda.rs
index 2eff324dc1..51dca6547c 100644
--- a/coding/src/zoda.rs
+++ b/coding/src/zoda.rs
@@ -126,10 +126,7 @@ use commonware_math::{
     ntt::{EvaluationVector, Matrix},
 };
 use commonware_parallel::Strategy;
-use commonware_storage::mmr::{
-    mem::DirtyMmr, verification::multi_proof, Error as MmrError, Location, Proof, StandardHasher,
-};
-use futures::executor::block_on;
+use commonware_storage::bmt::{Builder as BmtBuilder, Error as BmtError, Proof};
 use rand::seq::SliceRandom as _;
 use std::{marker::PhantomData, sync::Arc};
 use thiserror::Error;
@@ -439,7 +436,6 @@ impl Read for ReShard {
         let max_data_els = F::bits_to_elements(max_data_bits).max(1);
         Ok(Self {
             // Worst case: every row is one data element, and the sample size is all rows.
-            // TODO (#2506): use correct bounds on inclusion proof size
             inclusion_proof: Read::read_cfg(buf, &max_data_els)?,
             shard: Read::read_cfg(buf, &max_data_els)?,
         })
@@ -468,8 +464,15 @@ pub struct CheckedShard {
 /// Take indices up to `total`, and shuffle them.
 ///
 /// The shuffle depends, deterministically, on the transcript.
-fn shuffle_indices(transcript: &Transcript, total: usize) -> Vec<Location> {
-    let mut out = (0..total as u64).map(Location::from).collect::<Vec<_>>();
+///
+/// # Panics
+///
+/// Panics if `total` exceeds `u32::MAX`.
+fn shuffle_indices(transcript: &Transcript, total: usize) -> Vec<u32> {
+    let total: u32 = total
+        .try_into()
+        .expect("encoded_rows exceeds u32::MAX; data too large for ZODA");
+    let mut out = (0..total).collect::<Vec<_>>();
     out.shuffle(&mut transcript.noise(b"shuffle"));
     out
 }
@@ -492,7 +495,7 @@ pub struct CheckingData {
     root: H::Digest,
     checking_matrix: Matrix<F>,
     encoded_checksum: Matrix<F>,
-    shuffled_indices: Vec<Location>,
+    shuffled_indices: Vec<u32>,
 }
 
 impl CheckingData {
@@ -550,24 +553,29 @@
         let index = index as usize;
         let these_shuffled_indices = &self.shuffled_indices
             [index * self.topology.samples..(index + 1) * self.topology.samples];
-        let proof_elements = {
-            these_shuffled_indices
-                .iter()
-                .zip(reshard.shard.iter())
-                .map(|(&i, row)| (row_digest::<H, F>(row), i))
-                .collect::<Vec<_>>()
-        };
-        if !reshard.inclusion_proof.verify_multi_inclusion(
-            &mut StandardHasher::<H>::new(),
-            &proof_elements,
-            &self.root,
-        ) {
+
+        // Build elements for BMT multi-proof verification using the deterministically
+        // computed indices for this shard
+        let proof_elements: Vec<(H::Digest, u32)> = these_shuffled_indices
+            .iter()
+            .zip(reshard.shard.iter())
+            .map(|(&i, row)| (row_digest::<H, F>(row), i))
+            .collect();
+
+        // Verify the multi-proof
+        let mut hasher = H::new();
+        if reshard
+            .inclusion_proof
+            .verify_multi_inclusion(&mut hasher, &proof_elements, &self.root)
+            .is_err()
+        {
             return Err(Error::InvalidReShard);
         }
+
         let shard_checksum = reshard.shard.mul(&self.checking_matrix);
         // Check that the shard checksum rows match the encoded checksums
         for (row, &i) in shard_checksum.iter().zip(these_shuffled_indices) {
-            if row != &self.encoded_checksum[u64::from(i) as usize] {
+            if row != &self.encoded_checksum[i as usize] {
                 return Err(Error::InvalidReShard);
             }
         }
@@ -591,7 +599,7 @@ pub enum Error {
     #[error("insufficient unique rows {0} < {1}")]
     InsufficientUniqueRows(usize, usize),
     #[error("failed to create inclusion proof: {0}")]
-    FailedToCreateInclusionProof(MmrError),
+    FailedToCreateInclusionProof(BmtError),
 }
 
 // TODO (#2506): rename this to `_COMMONWARE_CODING_ZODA`
@@ -642,17 +650,16 @@ impl Scheme for Zoda {
             .evaluate()
             .data();
 
-        // Step 3: Commit to the rows of the data.
-        let mut hasher = StandardHasher::<H>::new();
-        let mut mmr = DirtyMmr::new();
-        let row_hashes = strategy.map_collect_vec(0..encoded_data.rows(), |i| {
+        // Step 3: Commit to the rows of the data using a Binary Merkle Tree.
+        let row_hashes: Vec<H::Digest> = strategy.map_collect_vec(0..encoded_data.rows(), |i| {
             row_digest::<H, F>(&encoded_data[i])
         });
+        let mut bmt_builder = BmtBuilder::<H>::new(row_hashes.len());
         for hash in &row_hashes {
-            mmr.add(&mut hasher, hash);
+            bmt_builder.add(hash);
         }
-        let mmr = mmr.merkleize(&mut hasher, None);
-        let root = *mmr.root();
+        let bmt = bmt_builder.build();
+        let root = bmt.root();
 
         // Step 4: Commit to the root, and the size of the data.
         let mut transcript = Transcript::new(NAMESPACE);
@@ -668,20 +675,20 @@
         // Step 6: Multiply the data with the checking matrix.
         let checksum = Arc::new(data.mul(&checking_matrix));
 
-        // Step 7: Produce the shards.
-        // We can't use "chunks" because we need to handle a sample size of 0
-        let index_chunks = (0..topology.total_shards)
-            .map(|i| &shuffled_indices[i * topology.samples..(i + 1) * topology.samples]);
-        let shards = index_chunks
-            .map(|indices| {
+        // Step 7: Produce the shards in parallel.
+        let shard_results: Vec<Result<Shard<H, F>, Error>> =
+            strategy.map_collect_vec(0..topology.total_shards, |shard_idx| {
+                let indices = &shuffled_indices
+                    [shard_idx * topology.samples..(shard_idx + 1) * topology.samples];
                 let rows = Matrix::init(
                     indices.len(),
                     topology.data_cols,
                     indices
                         .iter()
-                        .flat_map(|&i| encoded_data[u64::from(i) as usize].iter().copied()),
+                        .flat_map(|&i| encoded_data[i as usize].iter().copied()),
                 );
-                let inclusion_proof = block_on(multi_proof(&mmr, indices))
+                let inclusion_proof = bmt
+                    .multi_proof(indices)
                     .map_err(Error::FailedToCreateInclusionProof)?;
                 Ok(Shard {
                     data_bytes,
@@ -690,7 +697,9 @@
                     rows,
                     checksum: checksum.clone(),
                 })
-            })
+            });
+        let shards = shard_results
+            .into_iter()
            .collect::<Result<Vec<_>, Error>>()?;
         Ok((commitment, shards))
     }
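// Sketch (not part of this diff): the shape of the new Steps 3 and 7 above —
// commit per-row digests to a BMT, then open a sampled set of rows with a single
// multi-proof — using `Sha256` in place of the generic hasher `H`.
use commonware_cryptography::{Hasher as _, Sha256};
use commonware_storage::bmt::Builder;

fn zoda_commit_and_open_sketch() {
    // Step 3 (simplified): commit to one digest per encoded row.
    let row_hashes: Vec<_> = (0u32..16).map(|i| Sha256::hash(&i.to_be_bytes())).collect();
    let mut bmt_builder = Builder::<Sha256>::new(row_hashes.len());
    for hash in &row_hashes {
        bmt_builder.add(hash);
    }
    let bmt = bmt_builder.build();
    let root = bmt.root();

    // Step 7 (simplified): one multi-proof covers all sampled rows of a shard.
    let indices: Vec<u32> = vec![2, 7, 11];
    let inclusion_proof = bmt.multi_proof(&indices).unwrap();

    // A verifier recomputes the row digests and checks them in one call.
    let elements: Vec<_> = indices.iter().map(|&i| (row_hashes[i as usize], i)).collect();
    let mut hasher = Sha256::default();
    assert!(inclusion_proof
        .verify_multi_inclusion(&mut hasher, &elements, &root)
        .is_ok());
}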
diff --git a/storage/conformance.toml b/storage/conformance.toml
index 5d7f488019..237553081f 100644
--- a/storage/conformance.toml
+++ b/storage/conformance.toml
@@ -18,11 +18,11 @@ hash = "3cb6882637c1c1a929a50b3ab425311f3ef342184dc46a80b1eae616ca7b64a4"
 n_cases = 65536
 hash = "20f5ef35a4bbd3a40852e907df519c724e5ce24d9f929e84947fd971a2256d02"
 
-["commonware_storage::bmt::tests::conformance::CodecConformance>"]
+["commonware_storage::bmt::tests::conformance::CodecConformance>"]
 n_cases = 65536
-hash = "c1f1d4c35fcd50931d7c36cbcddbb1c0a93afef9a93945cdd3efadf68ff53328"
+hash = "6d6382956289a2f706581a4b1afa08c5cd8e8a4f55b11d454425333b6537cc17"
 
-["commonware_storage::bmt::tests::conformance::CodecConformance>"]
+["commonware_storage::bmt::tests::conformance::CodecConformance>"]
 n_cases = 65536
 hash = "6ecb0491b09443f1f93c178af5472f138ddc71b3e8c0c106f32eafca617b56af"
diff --git a/storage/fuzz/fuzz_targets/bmt_operations.rs b/storage/fuzz/fuzz_targets/bmt_operations.rs
index ef2844e1e4..e073d65d08 100644
--- a/storage/fuzz/fuzz_targets/bmt_operations.rs
+++ b/storage/fuzz/fuzz_targets/bmt_operations.rs
@@ -1,8 +1,8 @@
 #![no_main]
 
 use arbitrary::Arbitrary;
-use commonware_codec::{DecodeExt, Encode};
-use commonware_cryptography::{Hasher as _, Sha256};
+use commonware_codec::{Decode, DecodeExt, Encode};
+use commonware_cryptography::{sha256::Digest as Sha256Digest, Hasher as _, Sha256};
 use commonware_storage::bmt::{Builder, Proof, RangeProof};
 use libfuzzer_sys::fuzz_target;
 
@@ -61,6 +61,28 @@ enum BmtOperation {
         start: u32,
         tampered_values: Vec<u64>,
     },
+    // Multi-proof operations
+    GenerateMultiProof {
+        positions: Vec<u32>,
+    },
+    VerifyMultiProof {
+        elements: Vec<(u64, u32)>,
+    },
+    DeserializeMultiProof {
+        data: Vec<u8>,
+        max_items: u8,
+    },
+    // Multi-proof edge cases
+    MultiProofDuplicatePositions {
+        position: u32,
+        count: u8,
+    },
+    VerifyMultiProofWrongElements {
+        tampered_elements: Vec<(u64, u32)>,
+    },
+    VerifyMultiProofPartialElements {
+        skip_count: u8,
+    },
 }
 
 #[derive(Arbitrary, Debug)]
@@ -71,8 +93,10 @@ struct FuzzInput {
 
 fn fuzz(input: FuzzInput) {
     let mut builder: Option<Builder<Sha256>> = None;
     let mut tree = None;
-    let mut proof = None;
+    let mut proof: Option<Proof<Sha256Digest>> = None;
     let mut range_proof = None;
+    let mut multi_proof: Option<Proof<Sha256Digest>> = None;
+    let mut multi_proof_positions: Vec<u32> = Vec::new();
     let mut leaf_values = Vec::new();
 
     for op in input.operations.iter() {
@@ -119,7 +143,7 @@
                     let mut hasher = Sha256::default();
                     let leaf_digest = Sha256::hash(&leaf_value.to_be_bytes());
                     let root = t.root();
-                    let _ = p.verify(&mut hasher, &leaf_digest, *position, &root);
+                    let _ = p.verify_element_inclusion(&mut hasher, &leaf_digest, *position, &root);
                 }
             }
 
@@ -130,7 +154,8 @@
             }
 
             BmtOperation::DeserializeProof { data } => {
-                let _ = Proof::<Sha256>::decode(&mut data.as_slice());
+                // Use max_items=1 since we're fuzzing single-element proofs
+                let _ = Proof::<Sha256Digest>::decode_cfg(&mut data.as_slice(), &1);
             }
 
             BmtOperation::BuildEmptyTree => {
@@ -186,7 +211,7 @@
             }
 
             BmtOperation::DeserializeRangeProof { data } => {
-                let _ = RangeProof::<Sha256>::decode(&mut data.as_slice());
+                let _ = RangeProof::<Sha256Digest>::decode(&mut data.as_slice());
             }
 
             // Range proof edge cases
@@ -256,6 +281,85 @@
                     let _ = rp.verify(&mut hasher, *start, &tampered_digests, &root);
                 }
             }
+
+            // Multi-proof operations
+            BmtOperation::GenerateMultiProof { positions } => {
+                if let Some(ref t) = tree {
+                    // Limit positions to avoid excessive memory usage
+                    let limited_positions: Vec<u32> = positions.iter().take(20).copied().collect();
+                    if let Ok(mp) = t.multi_proof(&limited_positions) {
+                        multi_proof = Some(mp);
+                        multi_proof_positions = limited_positions;
+                    }
+                }
+            }
+
+            BmtOperation::VerifyMultiProof { elements } => {
+                if let (Some(ref mp), Some(ref t)) = (&multi_proof, &tree) {
+                    let mut hasher = Sha256::default();
+                    // Convert (value, position) pairs to (digest, position)
+                    let element_digests: Vec<_> = elements
+                        .iter()
+                        .take(20) // Limit elements
+                        .map(|(v, pos)| (Sha256::hash(&v.to_be_bytes()), *pos))
+                        .collect();
+                    let root = t.root();
+                    let _ = mp.verify_multi_inclusion(&mut hasher, &element_digests, &root);
+                }
+            }
+
+            BmtOperation::DeserializeMultiProof { data, max_items } => {
+                // Use max_items from fuzz input, clamped to reasonable range
+                let max = (*max_items as usize).clamp(1, 100);
+                let _ = Proof::<Sha256Digest>::decode_cfg(&mut data.as_slice(), &max);
+            }
+
+            BmtOperation::MultiProofDuplicatePositions { position, count } => {
+                if let Some(ref t) = tree {
+                    // Create a positions array with duplicates
+                    let count = (*count as usize).clamp(2, 10);
+                    let positions: Vec<u32> = vec![*position; count];
+                    if let Ok(mp) = t.multi_proof(&positions) {
+                        multi_proof = Some(mp);
+                        multi_proof_positions = positions;
+                    }
+                }
+            }
+
+            BmtOperation::VerifyMultiProofWrongElements { tampered_elements } => {
+                if let (Some(ref mp), Some(ref t)) = (&multi_proof, &tree) {
+                    let mut hasher = Sha256::default();
+                    // Convert tampered (value, position) pairs to (digest, position)
+                    let tampered_digests: Vec<_> = tampered_elements
+                        .iter()
+                        .take(20)
+                        .map(|(v, pos)| (Sha256::hash(&v.to_be_bytes()), *pos))
+                        .collect();
+                    let root = t.root();
+                    let _ = mp.verify_multi_inclusion(&mut hasher, &tampered_digests, &root);
+                }
+            }
+
+            BmtOperation::VerifyMultiProofPartialElements { skip_count } => {
+                if let (Some(ref mp), Some(ref t)) = (&multi_proof, &tree) {
+                    if !multi_proof_positions.is_empty() && !leaf_values.is_empty() {
+                        let mut hasher = Sha256::default();
+                        // Skip some elements from the original proof
+                        let skip = (*skip_count as usize) % multi_proof_positions.len().max(1);
+                        let partial_elements: Vec<_> = multi_proof_positions
+                            .iter()
+                            .skip(skip)
+                            .filter_map(|&pos| {
+                                leaf_values
+                                    .get(pos as usize)
+                                    .map(|v| (Sha256::hash(&v.to_be_bytes()), pos))
+                            })
+                            .collect();
+                        let root = t.root();
+                        let _ = mp.verify_multi_inclusion(&mut hasher, &partial_elements, &root);
+                    }
+                }
+            }
         }
     }
 }
diff --git a/storage/fuzz/fuzz_targets/proofs_malleability.rs b/storage/fuzz/fuzz_targets/proofs_malleability.rs
index ab3cfffb8e..f9b2cf46b3 100644
--- a/storage/fuzz/fuzz_targets/proofs_malleability.rs
+++ b/storage/fuzz/fuzz_targets/proofs_malleability.rs
@@ -169,24 +169,24 @@ fn fuzz(input: FuzzInput) {
     let tree = builder.build();
     let root = tree.root();
 
-    for (idx, _) in digests.iter().enumerate() {
+    for (idx, digest) in digests.iter().enumerate() {
         let original_proof = tree.proof(idx as u32).unwrap();
         let mut hasher = Sha256::default();
 
         assert!(
             original_proof
-                .verify(&mut hasher, &digests[idx], idx as u32, &root)
+                .verify_element_inclusion(&mut hasher, digest, idx as u32, &root)
                 .is_ok(),
             "Original BMT proof must be valid"
         );
 
         for mutation in &input.mutations {
             let mut mutated_proof = original_proof.clone();
-            mutate_proof_bytes(&mut mutated_proof, mutation, &());
+            mutate_proof_bytes(&mut mutated_proof, mutation, &1);
 
             if mutated_proof != original_proof {
                 let is_valid = mutated_proof
-                    .verify(&mut hasher, &digests[idx], idx as u32, &root)
+                    .verify_element_inclusion(&mut hasher, digest, idx as u32, &root)
                     .is_ok();
                 assert!(!is_valid, "Mutated BMT proof must be invalid");
             }
diff --git a/storage/src/bmt/benches/bench.rs b/storage/src/bmt/benches/bench.rs
index 1bd3953681..174f61896d 100644
--- a/storage/src/bmt/benches/bench.rs
+++ b/storage/src/bmt/benches/bench.rs
@@ -1,7 +1,13 @@
 use criterion::criterion_main;
 
 mod build;
+mod prove_multi;
 mod prove_range;
 mod prove_single;
 
-criterion_main!(build::benches, prove_single::benches, prove_range::benches);
+criterion_main!(
+    build::benches,
+    prove_single::benches,
+    prove_multi::benches,
+    prove_range::benches
+);
diff --git a/storage/src/bmt/benches/prove_multi.rs b/storage/src/bmt/benches/prove_multi.rs
new file mode 100644
index 0000000000..1db92fcc3f
--- /dev/null
+++ b/storage/src/bmt/benches/prove_multi.rs
@@ -0,0 +1,56 @@
+use commonware_cryptography::{sha256, Hasher, Sha256};
+use commonware_math::algebra::Random as _;
+use commonware_storage::bmt::Builder;
+use criterion::{criterion_group, Criterion};
+use rand::{rngs::StdRng, seq::SliceRandom, SeedableRng};
+
+const SAMPLE_SIZE: usize = 100;
+
+fn bench_prove_multi(c: &mut Criterion) {
+    for n in [250, 1_000, 5_000, 10_000, 25_000, 50_000, 100_000] {
+        // Populate Binary Merkle Tree
+        let mut builder = Builder::<Sha256>::new(n);
+        let mut queries = Vec::with_capacity(n);
+        let mut sampler = StdRng::seed_from_u64(0);
+        for pos in 0..n {
+            let element = sha256::Digest::random(&mut sampler);
+            builder.add(&element);
+            queries.push((pos as u32, element));
+        }
+        let tree = builder.build();
+        let root = tree.root();
+
+        // Select SAMPLE_SIZE random elements without replacement and create/verify multi-proof
+        c.bench_function(
+            &format!("{}/n={} items={}", module_path!(), n, SAMPLE_SIZE),
+            |b| {
+                b.iter_batched(
+                    || {
+                        let samples: Vec<_> = queries
+                            .choose_multiple(&mut sampler, SAMPLE_SIZE)
+                            .cloned()
+                            .collect();
+                        let positions: Vec<u32> = samples.iter().map(|(pos, _)| *pos).collect();
+                        let proof = tree.multi_proof(&positions).unwrap();
+                        (samples, proof)
+                    },
+                    |(samples, proof)| {
+                        let mut hasher = Sha256::new();
+                        let elements: Vec<_> =
+                            samples.iter().map(|(pos, elem)| (*elem, *pos)).collect();
+                        assert!(proof
+                            .verify_multi_inclusion(&mut hasher, &elements, &root)
+                            .is_ok());
+                    },
+                    criterion::BatchSize::SmallInput,
+                )
+            },
+        );
+    }
+}
+
+criterion_group! {
+    name = benches;
+    config = Criterion::default().sample_size(10);
+    targets = bench_prove_multi
+}
diff --git a/storage/src/bmt/benches/prove_single.rs b/storage/src/bmt/benches/prove_single.rs
index e25f13cd2e..d4fb6304e0 100644
--- a/storage/src/bmt/benches/prove_single.rs
+++ b/storage/src/bmt/benches/prove_single.rs
@@ -36,7 +36,9 @@ fn bench_prove_single(c: &mut Criterion) {
                         let mut hasher = Sha256::new();
                         for (pos, element) in samples {
                             let proof = tree.proof(pos).unwrap();
-                            assert!(proof.verify(&mut hasher, &element, pos, &root).is_ok());
+                            assert!(proof
+                                .verify_element_inclusion(&mut hasher, &element, pos, &root)
+                                .is_ok());
                         }
                     },
                     criterion::BatchSize::SmallInput,
diff --git a/storage/src/bmt/mod.rs b/storage/src/bmt/mod.rs
index dc322e68cf..78aba127b7 100644
--- a/storage/src/bmt/mod.rs
+++ b/storage/src/bmt/mod.rs
@@ -12,9 +12,9 @@
 //! Level 0 (leaves): [hash(0,A),hash(1,B),hash(2,C)]
 //! ```
 //!
-//! A proof for a given leaf is generated by collecting the sibling at each level (from the leaf up to the root).
-//! An external process can then use this proof (with some trusted root) to verify that the leaf (at a fixed position)
-//! is part of the tree.
+//! A proof for one or more leaves is generated by collecting the siblings needed to reconstruct the root.
+//! An external process can then use this proof (with some trusted root) to verify that the leaves
+//! are part of the tree.
 //!
 //! # Example
 //!
@@ -37,9 +37,10 @@
 //! // Generate a proof for leaf at index 1
 //! let mut hasher = Sha256::default();
 //! let proof = tree.proof(1).unwrap();
-//! assert!(proof.verify(&mut hasher, &digests[1], 1, &root).is_ok());
+//! assert!(proof.verify_element_inclusion(&mut hasher, &digests[1], 1, &root).is_ok());
 //! ```
 
+use alloc::collections::btree_set::BTreeSet;
 use bytes::{Buf, BufMut};
 use commonware_codec::{EncodeSize, Read, ReadExt, ReadRangeExt, Write};
 use commonware_cryptography::{Digest, Hasher};
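// Sketch (not part of this diff): the module example above, extended to the new
// multi-proof API and its codec round-trip; the `max_items` read bound caps the
// accepted sibling count at `max_items * MAX_LEVELS`.
use commonware_codec::{Decode as _, Encode as _};
use commonware_cryptography::{sha256::Digest, Hasher as _, Sha256};
use commonware_storage::bmt::{Builder, Proof};

fn multi_proof_sketch() {
    let digests: Vec<_> = (0u32..8).map(|i| Sha256::hash(&i.to_be_bytes())).collect();
    let mut builder = Builder::<Sha256>::new(digests.len());
    for digest in &digests {
        builder.add(digest);
    }
    let tree = builder.build();
    let root = tree.root();

    // One proof for several non-contiguous leaves; shared siblings are deduplicated.
    let positions = [0u32, 3, 5];
    let proof = tree.multi_proof(&positions).unwrap();

    // Serialize, then deserialize with an upper bound on the number of proven items.
    let encoded = proof.encode();
    let decoded = Proof::<Digest>::decode_cfg(encoded, &positions.len()).unwrap();

    // Elements are (digest, position) pairs and may be supplied in any order.
    let elements: Vec<_> = positions.iter().map(|&p| (digests[p as usize], p)).collect();
    let mut hasher = Sha256::default();
    assert!(decoded
        .verify_multi_inclusion(&mut hasher, &elements, &root)
        .is_ok());
}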
@@ -48,7 +49,7 @@ use thiserror::Error;
 
 /// There should never be more than 255 levels in a proof (would mean the Binary Merkle Tree
 /// has more than 2^255 leaves).
-const MAX_LEVELS: usize = u8::MAX as usize;
+pub const MAX_LEVELS: usize = u8::MAX as usize;
 
 /// Errors that can occur when working with a Binary Merkle Tree (BMT).
 #[derive(Error, Debug)]
@@ -61,6 +62,8 @@ pub enum Error {
     NoLeaves,
     #[error("unaligned proof")]
     UnalignedProof,
+    #[error("duplicate position: {0}")]
+    DuplicatePosition(u32),
 }
 
 /// Constructor for a Binary Merkle Tree (BMT).
@@ -158,39 +161,10 @@ impl<H: Hasher> Tree<H> {
     /// Generates a Merkle proof for the leaf at `position`.
     ///
-    /// The proof contains the sibling digest at each level needed to reconstruct
-    /// the root.
-    pub fn proof(&self, position: u32) -> Result<Proof<H>, Error> {
-        // Ensure the position is within bounds
-        if self.empty || position >= self.levels.first().len().get() as u32 {
-            return Err(Error::InvalidPosition(position));
-        }
-
-        // For each level (except the root level) record the sibling
-        let mut siblings = Vec::with_capacity(self.levels.len().get() - 1);
-        let mut index = position as usize;
-        for level in &self.levels {
-            if level.is_singleton() {
-                break;
-            }
-            let sibling_index = if index.is_multiple_of(2) {
-                index + 1
-            } else {
-                index - 1
-            };
-            let sibling = if sibling_index < level.len().get() {
-                level[sibling_index]
-            } else {
-                // If no right child exists, use a duplicate of the current node.
-                //
-                // This doesn't affect the robustness of the proof (allow a non-existent position
-                // to be proven or enable multiple proofs to be generated from a single leaf).
-                level[index]
-            };
-            siblings.push(sibling);
-            index /= 2;
-        }
-        Ok(Proof { siblings })
+    /// This is a single-element multi-proof, which includes the minimal siblings
+    /// needed to reconstruct the root.
+    pub fn proof(&self, position: u32) -> Result<Proof<H::Digest>, Error> {
+        self.multi_proof(&[position])
     }
 
     /// Generates a Merkle range proof for a contiguous set of leaves from `start`
@@ -199,7 +173,7 @@
     /// The proof contains the minimal set of sibling digests needed to reconstruct
     /// the root for all elements in the range. This is more efficient than individual
     /// proofs when proving multiple consecutive elements.
-    pub fn range_proof(&self, start: u32, end: u32) -> Result<RangeProof<H>, Error> {
+    pub fn range_proof(&self, start: u32, end: u32) -> Result<RangeProof<H::Digest>, Error> {
         // For empty trees, return an empty proof
         if self.empty && start == 0 && end == 0 {
             return Ok(RangeProof::default());
@@ -256,96 +230,41 @@
         Ok(RangeProof { siblings })
     }
-}
-
-/// A Merkle proof for a leaf in a Binary Merkle Tree.
-#[derive(Clone, Debug, Eq)]
-pub struct Proof<H: Hasher> {
-    /// The sibling hashes from the leaf up to the root.
-    pub siblings: Vec<H::Digest>,
-}
-
-impl<H: Hasher> PartialEq for Proof<H>
-where
-    H::Digest: PartialEq,
-{
-    fn eq(&self, other: &Self) -> bool {
-        self.siblings == other.siblings
-    }
-}
-
-impl<H: Hasher> Write for Proof<H> {
-    fn write(&self, writer: &mut impl BufMut) {
-        self.siblings.write(writer);
-    }
-}
-
-impl<H: Hasher> Read for Proof<H> {
-    type Cfg = ();
-
-    fn read_cfg(reader: &mut impl Buf, _: &Self::Cfg) -> Result<Self, commonware_codec::Error> {
-        let siblings = Vec::<H::Digest>::read_range(reader, ..=MAX_LEVELS)?;
-        Ok(Self { siblings })
-    }
-}
 
+    /// Generates a Merkle proof for multiple non-contiguous leaves at the given `positions`.
+    ///
+    /// The proof contains the minimal set of sibling digests needed to reconstruct
+    /// the root for all elements at the specified positions. This is more efficient
+    /// than individual proofs when proving multiple elements because shared siblings
+    /// are deduplicated.
+    ///
+    /// Positions are sorted internally; duplicate positions will return an error.
+    pub fn multi_proof(&self, positions: &[u32]) -> Result<Proof<H::Digest>, Error> {
+        // Handle empty positions first - can't prove zero elements
+        if positions.is_empty() {
+            return Err(Error::NoLeaves);
+        }
 
-impl<H: Hasher> EncodeSize for Proof<H> {
-    fn encode_size(&self) -> usize {
-        self.siblings.encode_size()
-    }
-}
+        // Handle empty tree case
+        if self.empty {
+            return Err(Error::InvalidPosition(positions[0]));
+        }
 
-impl<H: Hasher> Proof<H> {
-    /// Verifies that a given `leaf` at `position` is included in a Binary Merkle Tree
-    /// with `root` using the provided `hasher`.
-    ///
-    /// The proof consists of sibling hashes stored from the leaf up to the root. At each level, if the current
-    /// node is a left child (even index), the sibling is combined to the right; if it is a right child (odd index),
-    /// the sibling is combined to the left.
-    pub fn verify(
-        &self,
-        hasher: &mut H,
-        leaf: &H::Digest,
-        mut position: u32,
-        root: &H::Digest,
-    ) -> Result<(), Error> {
-        // Compute the position-hashed leaf
-        hasher.update(&position.to_be_bytes());
-        hasher.update(leaf);
-        let mut computed = hasher.finalize();
-        for sibling in self.siblings.iter() {
-            // Determine the position of the sibling
-            let (left_node, right_node) = if position.is_multiple_of(2) {
-                (&computed, sibling)
-            } else {
-                (sibling, &computed)
-            };
+        let leaf_count = self.levels.first().len().get() as u32;
 
-            // Compute the parent digest
-            hasher.update(left_node);
-            hasher.update(right_node);
-            computed = hasher.finalize();
+        // Get required sibling positions (this validates positions and checks for duplicates)
+        let sibling_positions =
+            siblings_required_for_multi_proof(leaf_count, positions.iter().copied())?;
 
-            // Move up the tree
-            position /= 2;
-        }
-        let result = computed == *root;
-        if result {
-            Ok(())
-        } else {
-            Err(Error::InvalidProof(computed.to_string(), root.to_string()))
-        }
-    }
-}
+        // Collect sibling digests in order
+        let siblings: Vec<H::Digest> = sibling_positions
+            .iter()
+            .map(|&(level, index)| self.levels[level][index])
+            .collect();
 
-#[cfg(feature = "arbitrary")]
-impl<H: Hasher> arbitrary::Arbitrary<'_> for Proof<H>
-where
-    H::Digest: for<'a> arbitrary::Arbitrary<'a>,
-{
-    fn arbitrary(u: &mut arbitrary::Unstructured<'_>) -> arbitrary::Result<Self> {
-        Ok(Self {
-            siblings: u.arbitrary()?,
+        Ok(Proof {
+            leaf_count,
+            siblings,
         })
     }
 }
@@ -404,34 +323,25 @@
 
 /// A Merkle range proof for a contiguous set of leaves in a Binary Merkle Tree.
-#[derive(Clone, Debug, Eq)]
-pub struct RangeProof<H: Hasher> {
+#[derive(Clone, Debug, Eq, PartialEq)]
+pub struct RangeProof<D: Digest> {
     /// The sibling digests needed to prove all elements in the range.
     ///
     /// Organized by level, from leaves to root. Each level can have at most
     /// 2 siblings (one on the left boundary and one on the right boundary).
-    pub siblings: Vec<Vec<H::Digest>>,
-}
-
-impl<H: Hasher> PartialEq for RangeProof<H>
-where
-    H::Digest: PartialEq,
-{
-    fn eq(&self, other: &Self) -> bool {
-        self.siblings == other.siblings
-    }
+    pub siblings: Vec<Vec<D>>,
 }
 
-impl<H: Hasher> Default for RangeProof<H> {
+impl<D: Digest> Default for RangeProof<D> {
     fn default() -> Self {
         Self { siblings: vec![] }
     }
 }
 
 #[cfg(feature = "arbitrary")]
-impl<H: Hasher> arbitrary::Arbitrary<'_> for RangeProof<H>
+impl<D: Digest> arbitrary::Arbitrary<'_> for RangeProof<D>
 where
-    H::Digest: for<'a> arbitrary::Arbitrary<'a>,
+    D: for<'a> arbitrary::Arbitrary<'a>,
 {
     fn arbitrary(u: &mut arbitrary::Unstructured<'_>) -> arbitrary::Result<Self> {
         Ok(Self {
@@ -446,18 +356,18 @@ struct Node<D> {
     digest: D,
 }
 
-impl<H: Hasher> RangeProof<H> {
+impl<D: Digest> RangeProof<D> {
     /// Verifies that a given range of `leaves` starting at `position` are included
     /// in a Binary Merkle Tree with `root` using the provided `hasher`.
     ///
     /// The proof contains the set of sibling digests needed to reconstruct
     /// the root for all elements in the range.
-    pub fn verify(
+    pub fn verify<H: Hasher<Digest = D>>(
         &self,
         hasher: &mut H,
         position: u32,
-        leaves: &[H::Digest],
-        root: &H::Digest,
+        leaves: &[D],
+        root: &D,
     ) -> Result<(), Error> {
         // Handle empty tree case
         if position == 0 && leaves.is_empty() && self.siblings.is_empty() {
@@ -481,7 +391,7 @@
         }
 
         // Compute position-hashed leaves
-        let mut nodes: Vec<Node<H::Digest>> = Vec::new();
+        let mut nodes: Vec<Node<D>> = Vec::new();
         for (i, leaf) in leaves.iter().enumerate() {
             let leaf_position = position + i as u32;
             hasher.update(&leaf_position.to_be_bytes());
@@ -497,8 +407,8 @@
         // Check if we should have a left sibling
         let first_pos = nodes[0].position;
         let last_pos = nodes[nodes.len() - 1].position;
-        let needs_left = first_pos % 2 == 1;
-        let needs_right = last_pos % 2 == 0;
+        let needs_left = !first_pos.is_multiple_of(2);
+        let needs_right = last_pos.is_multiple_of(2);
         if needs_left != bounds.left.is_some() {
             return Err(Error::UnalignedProof);
         }
@@ -569,31 +479,344 @@
     }
 }
 
-impl<H: Hasher> Write for RangeProof<H> {
+impl<D: Digest> Write for RangeProof<D> {
     fn write(&self, writer: &mut impl BufMut) {
         self.siblings.write(writer);
     }
 }
 
-impl<H: Hasher> Read for RangeProof<H> {
+impl<D: Digest> Read for RangeProof<D> {
     type Cfg = ();
 
     fn read_cfg(reader: &mut impl Buf, _: &Self::Cfg) -> Result<Self, commonware_codec::Error> {
-        let siblings = Vec::<Vec<H::Digest>>::read_range(reader, ..=MAX_LEVELS)?;
+        let siblings = Vec::<Vec<D>>::read_range(reader, ..=MAX_LEVELS)?;
         Ok(Self { siblings })
     }
 }
 
-impl<H: Hasher> EncodeSize for RangeProof<H> {
+impl<D: Digest> EncodeSize for RangeProof<D> {
     fn encode_size(&self) -> usize {
         self.siblings.encode_size()
     }
 }
 
+/// A Merkle proof for multiple non-contiguous leaves in a Binary Merkle Tree.
+///
+/// This proof type is more space-efficient than generating individual proofs
+/// for each leaf because sibling nodes that are shared between multiple paths
+/// are deduplicated.
+#[derive(Clone, Debug, Eq, PartialEq)]
+pub struct Proof<D: Digest> {
+    /// The total number of leaves in the tree.
+    pub leaf_count: u32,
+    /// The deduplicated sibling digests required to verify all elements,
+    /// ordered by their position in the tree (level-major, then index within level).
+    pub siblings: Vec<D>,
+}
+
+impl<D: Digest> Default for Proof<D> {
+    fn default() -> Self {
+        Self {
+            leaf_count: 0,
+            siblings: vec![],
+        }
+    }
+}
+
+impl<D: Digest> Write for Proof<D> {
+    fn write(&self, writer: &mut impl BufMut) {
+        self.leaf_count.write(writer);
+        self.siblings.write(writer);
+    }
+}
+impl<D: Digest> Read for Proof<D> {
+    /// The maximum number of items being proven.
+    ///
+    /// The upper bound on sibling hashes is derived as `max_items * MAX_LEVELS`.
+    type Cfg = usize;
+
+    fn read_cfg(
+        reader: &mut impl Buf,
+        max_items: &Self::Cfg,
+    ) -> Result<Self, commonware_codec::Error> {
+        let leaf_count = u32::read(reader)?;
+        let max_siblings = max_items.saturating_mul(MAX_LEVELS);
+        let siblings = Vec::<D>::read_range(reader, ..=max_siblings)?;
+        Ok(Self {
+            leaf_count,
+            siblings,
+        })
+    }
+}
+
+impl<D: Digest> EncodeSize for Proof<D> {
+    fn encode_size(&self) -> usize {
+        self.leaf_count.encode_size() + self.siblings.encode_size()
+    }
+}
+
+#[cfg(feature = "arbitrary")]
+impl<D: Digest> arbitrary::Arbitrary<'_> for Proof<D>
+where
+    D: for<'a> arbitrary::Arbitrary<'a>,
+{
+    fn arbitrary(u: &mut arbitrary::Unstructured<'_>) -> arbitrary::Result<Self> {
+        Ok(Self {
+            leaf_count: u.arbitrary()?,
+            siblings: u.arbitrary()?,
+        })
+    }
+}
+
+/// Returns the number of levels in a tree with `leaf_count` leaves.
+/// A tree with 1 leaf has 1 level, a tree with 2 leaves has 2 levels, etc.
+const fn levels_in_tree(leaf_count: u32) -> usize {
+    (u32::BITS - (leaf_count.saturating_sub(1)).leading_zeros() + 1) as usize
+}
+
+/// Returns the sorted, deduplicated positions of siblings required to prove
+/// inclusion of leaves at the given positions.
+///
+/// Each position in the result is encoded as `(level, index)` where level 0 is the leaf level.
+fn siblings_required_for_multi_proof(
+    leaf_count: u32,
+    positions: impl IntoIterator<Item = u32>,
+) -> Result<BTreeSet<(usize, usize)>, Error> {
+    // Validate positions and check for duplicates.
+    let mut current = BTreeSet::new();
+    for pos in positions {
+        if pos >= leaf_count {
+            return Err(Error::InvalidPosition(pos));
+        }
+        if !current.insert(pos as usize) {
+            return Err(Error::DuplicatePosition(pos));
+        }
+    }
+
+    if current.is_empty() {
+        return Err(Error::NoLeaves);
+    }
+
+    // Track positions we can compute at each level and record missing siblings.
+    // This keeps the work proportional to the number of positions, not the tree size.
+    let mut sibling_positions = BTreeSet::new();
+    let levels_count = levels_in_tree(leaf_count);
+    let mut level_size = leaf_count as usize;
+    for level in 0..levels_count - 1 {
+        for &index in &current {
+            let sibling_index = if index.is_multiple_of(2) {
+                if index + 1 < level_size {
+                    index + 1
+                } else {
+                    index
+                }
+            } else {
+                index - 1
+            };
+
+            if sibling_index != index && !current.contains(&sibling_index) {
+                sibling_positions.insert((level, sibling_index));
+            }
+        }
+
+        current = current.iter().map(|idx| idx / 2).collect();
+        level_size = level_size.div_ceil(2);
+    }
+
+    Ok(sibling_positions)
+}
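// Sketch (not part of this diff): a worked example of the two helpers above,
// written as a module-local test since `siblings_required_for_multi_proof` is
// private. For an 8-leaf tree, proving positions {0, 3, 5} needs exactly four
// deduplicated siblings:
//   Level 0: current = {0, 3, 5} -> record (0,1), (0,2), (0,4)
//   Level 1: current = {0, 1, 2} -> 0 and 1 pair with each other; 2 needs (1,3)
//   Level 2: current = {0, 1}    -> pair with each other; nothing recorded
// Three individual proofs would carry 3 * 3 = 9 siblings; the multi-proof carries 4.
#[cfg(test)]
mod multi_proof_worked_example {
    use super::*;

    #[test]
    fn sibling_positions_for_0_3_5() {
        assert_eq!(levels_in_tree(8), 4);
        let siblings = siblings_required_for_multi_proof(8, [0u32, 3, 5]).unwrap();
        let expected = vec![(0, 1), (0, 2), (0, 4), (1, 3)];
        assert_eq!(siblings.into_iter().collect::<Vec<_>>(), expected);
    }
}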
+impl<D: Digest> Proof<D> {
+    /// Verifies that a given `leaf` at `position` is included in a Binary Merkle Tree
+    /// with `root` using the provided `hasher`.
+    ///
+    /// The proof consists of sibling hashes stored from the leaf up to the root. At each
+    /// level, if the current node is a left child (even index), the sibling is combined
+    /// to the right; if it is a right child (odd index), the sibling is combined to the
+    /// left.
+    pub fn verify_element_inclusion<H: Hasher<Digest = D>>(
+        &self,
+        hasher: &mut H,
+        leaf: &D,
+        mut position: u32,
+        root: &D,
+    ) -> Result<(), Error> {
+        // Validate position
+        if position >= self.leaf_count {
+            return Err(Error::InvalidPosition(position));
+        }
+
+        // Compute the position-hashed leaf
+        hasher.update(&position.to_be_bytes());
+        hasher.update(leaf);
+        let mut computed = hasher.finalize();
+
+        // Track level size to handle odd-sized levels
+        let mut level_size = self.leaf_count as usize;
+        let mut sibling_iter = self.siblings.iter();
+
+        // Traverse from leaf to root
+        while level_size > 1 {
+            // Check if this is the last node at an odd-sized level (no real sibling)
+            let is_last_odd = position.is_multiple_of(2) && position as usize + 1 >= level_size;
+
+            let (left_node, right_node) = if is_last_odd {
+                // Node is duplicated - no sibling consumed from proof
+                (&computed, &computed)
+            } else if position.is_multiple_of(2) {
+                // Even position: sibling is to the right
+                let sibling = sibling_iter.next().ok_or(Error::UnalignedProof)?;
+                (&computed, sibling)
+            } else {
+                // Odd position: sibling is to the left
+                let sibling = sibling_iter.next().ok_or(Error::UnalignedProof)?;
+                (sibling, &computed)
+            };
+
+            // Compute the parent digest
+            hasher.update(left_node);
+            hasher.update(right_node);
+            computed = hasher.finalize();
+
+            // Move up the tree
+            position /= 2;
+            level_size = level_size.div_ceil(2);
+        }
+
+        // Ensure all siblings were consumed
+        if sibling_iter.next().is_some() {
+            return Err(Error::UnalignedProof);
+        }
+
+        if computed == *root {
+            Ok(())
+        } else {
+            Err(Error::InvalidProof(computed.to_string(), root.to_string()))
+        }
+    }
+
+    /// Verifies that the given `elements` at their respective positions are included
+    /// in a Binary Merkle Tree with `root`.
+    ///
+    /// Elements can be provided in any order; positions are sorted internally.
+    /// Duplicate positions will cause verification to fail.
+    pub fn verify_multi_inclusion<H: Hasher<Digest = D>>(
+        &self,
+        hasher: &mut H,
+        elements: &[(D, u32)],
+        root: &D,
+    ) -> Result<(), Error> {
+        // Handle empty case
+        if elements.is_empty() {
+            if self.leaf_count == 0 && self.siblings.is_empty() {
+                let empty_root = hasher.finalize();
+                if empty_root == *root {
+                    return Ok(());
+                } else {
+                    return Err(Error::InvalidProof(
+                        empty_root.to_string(),
+                        root.to_string(),
+                    ));
+                }
+            }
+            return Err(Error::NoLeaves);
+        }
+
+        // 1. Sort elements by position and check for duplicates/bounds
+        let mut sorted: Vec<(u32, D)> = Vec::with_capacity(elements.len());
+        for (leaf, position) in elements {
+            if *position >= self.leaf_count {
+                return Err(Error::InvalidPosition(*position));
+            }
+            hasher.update(&position.to_be_bytes());
+            hasher.update(leaf);
+            sorted.push((*position, hasher.finalize()));
+        }
+        sorted.sort_unstable_by_key(|(pos, _)| *pos);
+
+        // Check for duplicates (adjacent elements with same position after sorting)
+        for i in 1..sorted.len() {
+            if sorted[i - 1].0 == sorted[i].0 {
+                return Err(Error::DuplicatePosition(sorted[i].0));
+            }
+        }
+
+        // 2. Iterate up the tree
+        // Since we process left-to-right and parent_pos = pos/2, next_level stays sorted.
+        let levels = levels_in_tree(self.leaf_count);
+        let mut level_size = self.leaf_count;
+        let mut sibling_iter = self.siblings.iter();
+        let mut current = sorted;
+        let mut next_level: Vec<(u32, D)> = Vec::with_capacity(current.len());
+
+        for _ in 0..levels - 1 {
+            let mut idx = 0;
+            while idx < current.len() {
+                let (pos, digest) = current[idx];
+                let parent_pos = pos / 2;
+
+                // Determine if we have the left or right child
+                let (left, right) = if pos % 2 == 0 {
+                    // We are the LEFT child
+                    let left = digest;
+
+                    // Check if we have the right child in our current set
+                    let right = if idx + 1 < current.len() && current[idx + 1].0 == pos + 1 {
+                        idx += 1;
+                        current[idx].1
+                    } else if pos + 1 >= level_size {
+                        // If no right child exists in tree, duplicate left
+                        left
+                    } else {
+                        // Otherwise, must consume a sibling
+                        *sibling_iter.next().ok_or(Error::UnalignedProof)?
+                    };
+                    (left, right)
+                } else {
+                    // We are the RIGHT child
+                    // This implies the LEFT child was missing from 'current', so it must be a sibling.
+                    let right = digest;
+                    let left = *sibling_iter.next().ok_or(Error::UnalignedProof)?;
+                    (left, right)
+                };
+
+                // Hash parent
+                hasher.update(&left);
+                hasher.update(&right);
+                next_level.push((parent_pos, hasher.finalize()));
+
+                idx += 1;
+            }
+
+            // Prepare for next level
+            core::mem::swap(&mut current, &mut next_level);
+            next_level.clear();
+            level_size = level_size.div_ceil(2);
+        }
+
+        // 3. Verify root
+        if sibling_iter.next().is_some() {
+            return Err(Error::UnalignedProof);
+        }
+
+        if current.len() != 1 {
+            return Err(Error::UnalignedProof);
+        }
+
+        let computed = current[0].1;
+        if computed == *root {
+            Ok(())
+        } else {
+            Err(Error::InvalidProof(computed.to_string(), root.to_string()))
+        }
+    }
+}
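// Sketch (not part of this diff): a short trace of `verify_multi_inclusion` for
// an 8-leaf tree and elements at positions {2, 3}, matching the generator above.
// The proof carries two siblings, in (level, index) order: (1, 0) and (2, 1).
//
//   Level 0: pos 2 is a left child and pos 3 (its right sibling) is also supplied,
//            so the pair is hashed internally and no proof sibling is consumed.
//   Level 1: pos 1 is a right child, so the left sibling (1, 0) is consumed.
//   Level 2: pos 0 is a left child with a real right neighbor, so (2, 1) is consumed.
//
// All siblings are consumed and exactly one digest remains, which must equal `root`.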
remove_tamper.siblings.pop(); assert!( remove_tamper - .verify(&mut hasher, leaf, i as u32, &root) + .verify_element_inclusion(&mut hasher, leaf, i as u32, &root) .is_err(), "remove fail for size={n} leaf={i}" ); @@ -895,7 +1120,7 @@ mod tests { let element = &digests[0]; // Build tree - let mut builder = Builder::new(txs.len()); + let mut builder = Builder::::new(txs.len()); for digest in &digests { builder.add(digest); } @@ -910,7 +1135,9 @@ mod tests { // Fail verification with an empty proof. let mut hasher = Sha256::default(); - assert!(proof.verify(&mut hasher, element, 0, &root).is_err()); + assert!(proof + .verify_element_inclusion(&mut hasher, element, 0, &root) + .is_err()); } #[test] @@ -921,7 +1148,7 @@ mod tests { let element = &digests[0]; // Build tree - let mut builder = Builder::new(txs.len()); + let mut builder = Builder::::new(txs.len()); for digest in &digests { builder.add(digest); } @@ -934,9 +1161,11 @@ mod tests { // Tamper with proof proof.siblings.push(*element); - // Fail verification with an empty proof. + // Fail verification with extra sibling let mut hasher = Sha256::default(); - assert!(proof.verify(&mut hasher, element, 0, &root).is_err()); + assert!(proof + .verify_element_inclusion(&mut hasher, element, 0, &root) + .is_err()); } #[test] @@ -946,7 +1175,7 @@ mod tests { let digests: Vec = txs.iter().map(|tx| Sha256::hash(*tx)).collect(); // Build tree - let mut builder = Builder::new(txs.len()); + let mut builder = Builder::::new(txs.len()); for digest in &digests { builder.add(digest); } @@ -959,7 +1188,9 @@ mod tests { // Use a wrong element (e.g. hash of a different transaction). let mut hasher = Sha256::default(); let wrong_leaf = Sha256::hash(b"wrong_tx"); - assert!(proof.verify(&mut hasher, &wrong_leaf, 2, &root).is_err()); + assert!(proof + .verify_element_inclusion(&mut hasher, &wrong_leaf, 2, &root) + .is_err()); } #[test] @@ -969,7 +1200,7 @@ mod tests { let digests: Vec = txs.iter().map(|tx| Sha256::hash(*tx)).collect(); // Build tree - let mut builder = Builder::new(txs.len()); + let mut builder = Builder::::new(txs.len()); for digest in &digests { builder.add(digest); } @@ -981,7 +1212,9 @@ mod tests { // Use an incorrect index (e.g. 2 instead of 1). let mut hasher = Sha256::default(); - assert!(proof.verify(&mut hasher, &digests[1], 2, &root).is_err()); + assert!(proof + .verify_element_inclusion(&mut hasher, &digests[1], 2, &root) + .is_err()); } #[test] @@ -991,7 +1224,7 @@ mod tests { let digests: Vec = txs.iter().map(|tx| Sha256::hash(*tx)).collect(); // Build tree - let mut builder = Builder::new(txs.len()); + let mut builder = Builder::::new(txs.len()); for digest in &digests { builder.add(digest); } @@ -1004,7 +1237,7 @@ mod tests { let mut hasher = Sha256::default(); let wrong_root = Sha256::hash(b"wrong_root"); assert!(proof - .verify(&mut hasher, &digests[0], 0, &wrong_root) + .verify_element_inclusion(&mut hasher, &digests[0], 0, &wrong_root) .is_err()); } @@ -1027,7 +1260,7 @@ mod tests { // Truncate one byte. serialized.truncate(serialized.len() - 1); - assert!(Proof::::decode(&mut serialized).is_err()); + assert!(Proof::::decode_cfg(&mut serialized, &1).is_err()); } #[test] @@ -1049,7 +1282,7 @@ mod tests { // Append an extra byte. 
serialized.extend_from_slice(&[0u8]); - assert!(Proof::::decode(&mut serialized).is_err()); + assert!(Proof::::decode_cfg(&mut serialized, &1).is_err()); } #[test] @@ -1059,7 +1292,7 @@ mod tests { let digests: Vec = txs.iter().map(|tx| Sha256::hash(*tx)).collect(); // Build tree - let mut builder = Builder::new(txs.len()); + let mut builder = Builder::::new(txs.len()); for digest in &digests { builder.add(digest); } @@ -1072,7 +1305,9 @@ mod tests { // Modify the first hash in the proof. let mut hasher = Sha256::default(); proof.siblings[0] = Sha256::hash(b"modified"); - assert!(proof.verify(&mut hasher, &digests[2], 2, &root).is_err()); + assert!(proof + .verify_element_inclusion(&mut hasher, &digests[2], 2, &root) + .is_err()); } #[test] @@ -1082,7 +1317,7 @@ mod tests { let digests: Vec = txs.iter().map(|tx| Sha256::hash(*tx)).collect(); // Build tree - let mut builder = Builder::new(txs.len()); + let mut builder = Builder::::new(txs.len()); for digest in &digests { builder.add(digest); } @@ -1094,14 +1329,18 @@ mod tests { // Verification should succeed for the proper index 2. let mut hasher = Sha256::default(); - assert!(proof.verify(&mut hasher, &digests[2], 2, &root).is_ok()); + assert!(proof + .verify_element_inclusion(&mut hasher, &digests[2], 2, &root) + .is_ok()); // Should not be able to generate a proof for an out-of-range index (e.g. 3). assert!(tree.proof(3).is_err()); // Attempting to verify using an out-of-range index (e.g. 3, which would correspond // to a duplicate leaf that doesn't actually exist) should fail. - assert!(proof.verify(&mut hasher, &digests[2], 3, &root).is_err()); + assert!(proof + .verify_element_inclusion(&mut hasher, &digests[2], 3, &root) + .is_err()); } #[test] @@ -1128,7 +1367,7 @@ mod tests { // Serialize and deserialize let mut serialized = range_proof.encode(); - let deserialized = RangeProof::::decode(&mut serialized).unwrap(); + let deserialized = RangeProof::::decode(&mut serialized).unwrap(); assert!(deserialized .verify(&mut hasher, 2, range_leaves, &root) .is_ok()); @@ -1588,9 +1827,9 @@ mod tests { #[test] fn test_empty_range_proof_serialization() { - let range_proof = RangeProof::::default(); + let range_proof = RangeProof::::default(); let mut serialized = range_proof.encode(); - let deserialized = RangeProof::::decode(&mut serialized).unwrap(); + let deserialized = RangeProof::::decode(&mut serialized).unwrap(); assert_eq!(range_proof, deserialized); } @@ -1712,6 +1951,918 @@ mod tests { .is_ok()); } + #[test] + fn test_multi_proof_basic() { + // Create test data + let digests: Vec = (0..8u32).map(|i| Sha256::hash(&i.to_be_bytes())).collect(); + + // Build tree + let mut builder = Builder::::new(digests.len()); + for digest in &digests { + builder.add(digest); + } + let tree = builder.build(); + let root = tree.root(); + + // Test multi-proof for non-contiguous positions [0, 3, 5] + let positions = [0, 3, 5]; + let multi_proof = tree.multi_proof(&positions).unwrap(); + let mut hasher = Sha256::default(); + + let elements: Vec<(Digest, u32)> = positions + .iter() + .map(|&p| (digests[p as usize], p)) + .collect(); + assert!(multi_proof + .verify_multi_inclusion(&mut hasher, &elements, &root) + .is_ok()); + } + + #[test] + fn test_multi_proof_single_element() { + // Create test data + let digests: Vec = (0..8u32).map(|i| Sha256::hash(&i.to_be_bytes())).collect(); + + // Build tree + let mut builder = Builder::::new(digests.len()); + for digest in &digests { + builder.add(digest); + } + let tree = builder.build(); + let root = 
tree.root(); + let mut hasher = Sha256::default(); + + // Test single element multi-proof for each position + for (i, digest) in digests.iter().enumerate() { + let multi_proof = tree.multi_proof(&[i as u32]).unwrap(); + let elements = [(*digest, i as u32)]; + assert!( + multi_proof + .verify_multi_inclusion(&mut hasher, &elements, &root) + .is_ok(), + "Failed for position {i}" + ); + } + } + + #[test] + fn test_multi_proof_all_elements() { + // Create test data + let digests: Vec = (0..8u32).map(|i| Sha256::hash(&i.to_be_bytes())).collect(); + + // Build tree + let mut builder = Builder::::new(digests.len()); + for digest in &digests { + builder.add(digest); + } + let tree = builder.build(); + let root = tree.root(); + let mut hasher = Sha256::default(); + + // Test multi-proof for all elements + let positions: Vec = (0..digests.len() as u32).collect(); + let multi_proof = tree.multi_proof(&positions).unwrap(); + + let elements: Vec<(Digest, u32)> = positions + .iter() + .map(|&p| (digests[p as usize], p)) + .collect(); + assert!(multi_proof + .verify_multi_inclusion(&mut hasher, &elements, &root) + .is_ok()); + + // When proving all elements, we shouldn't need any siblings (all can be computed) + assert!(multi_proof.siblings.is_empty()); + } + + #[test] + fn test_multi_proof_adjacent_elements() { + // Create test data + let digests: Vec = (0..8u32).map(|i| Sha256::hash(&i.to_be_bytes())).collect(); + + // Build tree + let mut builder = Builder::::new(digests.len()); + for digest in &digests { + builder.add(digest); + } + let tree = builder.build(); + let root = tree.root(); + let mut hasher = Sha256::default(); + + // Test adjacent positions (should deduplicate shared siblings) + let positions = [2, 3]; + let multi_proof = tree.multi_proof(&positions).unwrap(); + + let elements: Vec<(Digest, u32)> = positions + .iter() + .map(|&p| (digests[p as usize], p)) + .collect(); + assert!(multi_proof + .verify_multi_inclusion(&mut hasher, &elements, &root) + .is_ok()); + } + + #[test] + fn test_multi_proof_sparse_positions() { + // Create test data + let digests: Vec = (0..16u32).map(|i| Sha256::hash(&i.to_be_bytes())).collect(); + + // Build tree + let mut builder = Builder::::new(digests.len()); + for digest in &digests { + builder.add(digest); + } + let tree = builder.build(); + let root = tree.root(); + let mut hasher = Sha256::default(); + + // Test widely separated positions + let positions = [0, 7, 8, 15]; + let multi_proof = tree.multi_proof(&positions).unwrap(); + + let elements: Vec<(Digest, u32)> = positions + .iter() + .map(|&p| (digests[p as usize], p)) + .collect(); + assert!(multi_proof + .verify_multi_inclusion(&mut hasher, &elements, &root) + .is_ok()); + } + + #[test] + fn test_multi_proof_empty_tree() { + // Build empty tree + let builder = Builder::::new(0); + let tree = builder.build(); + + // Empty tree with empty positions should return NoLeaves error + // (we can't prove zero elements) + assert!(matches!(tree.multi_proof(&[]), Err(Error::NoLeaves))); + + // Empty tree with any position should fail with InvalidPosition + assert!(matches!( + tree.multi_proof(&[0]), + Err(Error::InvalidPosition(0)) + )); + } + + #[test] + fn test_multi_proof_empty_positions() { + // Create test data + let digests: Vec = (0..8u32).map(|i| Sha256::hash(&i.to_be_bytes())).collect(); + + // Build tree + let mut builder = Builder::::new(digests.len()); + for digest in &digests { + builder.add(digest); + } + let tree = builder.build(); + + // Empty positions should return error + 
assert!(matches!(tree.multi_proof(&[]), Err(Error::NoLeaves))); + } + + #[test] + fn test_multi_proof_duplicate_positions_error() { + // Create test data + let digests: Vec = (0..8u32).map(|i| Sha256::hash(&i.to_be_bytes())).collect(); + + // Build tree + let mut builder = Builder::::new(digests.len()); + for digest in &digests { + builder.add(digest); + } + let tree = builder.build(); + + // Duplicate positions should return error + assert!(matches!( + tree.multi_proof(&[1, 1]), + Err(Error::DuplicatePosition(1)) + )); + assert!(matches!( + tree.multi_proof(&[0, 2, 2, 5]), + Err(Error::DuplicatePosition(2)) + )); + } + + #[test] + fn test_multi_proof_unsorted_input() { + // Create test data + let digests: Vec = (0..8u32).map(|i| Sha256::hash(&i.to_be_bytes())).collect(); + + // Build tree + let mut builder = Builder::::new(digests.len()); + for digest in &digests { + builder.add(digest); + } + let tree = builder.build(); + let root = tree.root(); + let mut hasher = Sha256::default(); + + // Test with unsorted positions (should work - internal sorting) + let positions = [5, 0, 3]; + let multi_proof = tree.multi_proof(&positions).unwrap(); + + // Verify with unsorted elements (should work - internal sorting) + let unsorted_elements = [(digests[5], 5), (digests[0], 0), (digests[3], 3)]; + assert!(multi_proof + .verify_multi_inclusion(&mut hasher, &unsorted_elements, &root) + .is_ok()); + } + + #[test] + fn test_multi_proof_various_sizes() { + // Test multi-proofs for trees of various sizes + for tree_size in [1, 2, 3, 4, 5, 7, 8, 15, 16, 31, 32] { + let digests: Vec = (0..tree_size as u32) + .map(|i| Sha256::hash(&i.to_be_bytes())) + .collect(); + + // Build tree + let mut builder = Builder::::new(digests.len()); + for digest in &digests { + builder.add(digest); + } + let tree = builder.build(); + let root = tree.root(); + let mut hasher = Sha256::default(); + + // Test various position combinations + // First and last + if tree_size >= 2 { + let positions = [0, (tree_size - 1) as u32]; + let multi_proof = tree.multi_proof(&positions).unwrap(); + let elements: Vec<(Digest, u32)> = positions + .iter() + .map(|&p| (digests[p as usize], p)) + .collect(); + assert!( + multi_proof + .verify_multi_inclusion(&mut hasher, &elements, &root) + .is_ok(), + "Failed for tree_size={tree_size}, positions=[0, {}]", + tree_size - 1 + ); + } + + // Every other element + if tree_size >= 4 { + let positions: Vec = (0..tree_size as u32).step_by(2).collect(); + let multi_proof = tree.multi_proof(&positions).unwrap(); + let elements: Vec<(Digest, u32)> = positions + .iter() + .map(|&p| (digests[p as usize], p)) + .collect(); + assert!( + multi_proof + .verify_multi_inclusion(&mut hasher, &elements, &root) + .is_ok(), + "Failed for tree_size={tree_size}, every other element" + ); + } + } + } + + #[test] + fn test_multi_proof_wrong_elements() { + // Create test data + let digests: Vec = (0..8u32).map(|i| Sha256::hash(&i.to_be_bytes())).collect(); + + // Build tree + let mut builder = Builder::::new(digests.len()); + for digest in &digests { + builder.add(digest); + } + let tree = builder.build(); + let root = tree.root(); + let mut hasher = Sha256::default(); + + // Generate valid proof + let positions = [0, 3, 5]; + let multi_proof = tree.multi_proof(&positions).unwrap(); + + // Verify with wrong elements + let wrong_elements = [ + (Sha256::hash(b"wrong1"), 0), + (digests[3], 3), + (digests[5], 5), + ]; + assert!(multi_proof + .verify_multi_inclusion(&mut hasher, &wrong_elements, &root) + .is_err()); + } + + 
#[test] + fn test_multi_proof_wrong_positions() { + // Create test data + let digests: Vec = (0..8u32).map(|i| Sha256::hash(&i.to_be_bytes())).collect(); + + // Build tree + let mut builder = Builder::::new(digests.len()); + for digest in &digests { + builder.add(digest); + } + let tree = builder.build(); + let root = tree.root(); + let mut hasher = Sha256::default(); + + // Generate valid proof + let positions = [0, 3, 5]; + let multi_proof = tree.multi_proof(&positions).unwrap(); + + // Verify with wrong positions (same elements, different positions) + let wrong_positions = [ + (digests[0], 1), // wrong position + (digests[3], 3), + (digests[5], 5), + ]; + assert!(multi_proof + .verify_multi_inclusion(&mut hasher, &wrong_positions, &root) + .is_err()); + } + + #[test] + fn test_multi_proof_wrong_root() { + // Create test data + let digests: Vec = (0..8u32).map(|i| Sha256::hash(&i.to_be_bytes())).collect(); + + // Build tree + let mut builder = Builder::::new(digests.len()); + for digest in &digests { + builder.add(digest); + } + let tree = builder.build(); + let mut hasher = Sha256::default(); + + // Generate valid proof + let positions = [0, 3, 5]; + let multi_proof = tree.multi_proof(&positions).unwrap(); + + let elements: Vec<(Digest, u32)> = positions + .iter() + .map(|&p| (digests[p as usize], p)) + .collect(); + + // Verify with wrong root + let wrong_root = Sha256::hash(b"wrong_root"); + assert!(multi_proof + .verify_multi_inclusion(&mut hasher, &elements, &wrong_root) + .is_err()); + } + + #[test] + fn test_multi_proof_tampering() { + // Create test data + let digests: Vec = (0..8u32).map(|i| Sha256::hash(&i.to_be_bytes())).collect(); + + // Build tree + let mut builder = Builder::::new(digests.len()); + for digest in &digests { + builder.add(digest); + } + let tree = builder.build(); + let root = tree.root(); + let mut hasher = Sha256::default(); + + // Generate valid proof + let positions = [0, 5]; + let multi_proof = tree.multi_proof(&positions).unwrap(); + + let elements: Vec<(Digest, u32)> = positions + .iter() + .map(|&p| (digests[p as usize], p)) + .collect(); + + // Tamper with sibling + assert!(!multi_proof.siblings.is_empty()); + let mut modified = multi_proof.clone(); + modified.siblings[0] = Sha256::hash(b"tampered"); + assert!(modified + .verify_multi_inclusion(&mut hasher, &elements, &root) + .is_err()); + + // Add extra sibling + let mut extra = multi_proof.clone(); + extra.siblings.push(Sha256::hash(b"extra")); + assert!(extra + .verify_multi_inclusion(&mut hasher, &elements, &root) + .is_err()); + + // Remove a sibling + let mut missing = multi_proof; + missing.siblings.pop(); + assert!(missing + .verify_multi_inclusion(&mut hasher, &elements, &root) + .is_err()); + } + + #[test] + fn test_multi_proof_deduplication() { + // Create test data + let digests: Vec = (0..16u32).map(|i| Sha256::hash(&i.to_be_bytes())).collect(); + + // Build tree + let mut builder = Builder::::new(digests.len()); + for digest in &digests { + builder.add(digest); + } + let tree = builder.build(); + + // Get individual proofs + let individual_siblings: usize = [0u32, 1, 8, 9] + .iter() + .map(|&p| tree.proof(p).unwrap().siblings.len()) + .sum(); + + // Get multi-proof for same positions + let multi_proof = tree.multi_proof(&[0, 1, 8, 9]).unwrap(); + + // Multi-proof should have fewer siblings due to deduplication + assert!( + multi_proof.siblings.len() < individual_siblings, + "Multi-proof ({}) should have fewer siblings than sum of individual proofs ({})", + multi_proof.siblings.len(), + 
+
+    #[test]
+    fn test_multi_proof_serialization() {
+        // Create test data
+        let digests: Vec<Digest> = (0..8u32).map(|i| Sha256::hash(&i.to_be_bytes())).collect();
+
+        // Build tree
+        let mut builder = Builder::<Sha256>::new(digests.len());
+        for digest in &digests {
+            builder.add(digest);
+        }
+        let tree = builder.build();
+        let root = tree.root();
+        let mut hasher = Sha256::default();
+
+        // Generate proof
+        let positions = [0, 3, 5];
+        let multi_proof = tree.multi_proof(&positions).unwrap();
+
+        // Serialize and deserialize
+        let serialized = multi_proof.encode();
+        let deserialized = Proof::<Digest>::decode_cfg(serialized, &positions.len()).unwrap();
+
+        assert_eq!(multi_proof, deserialized);
+
+        // Verify deserialized proof works
+        let elements: Vec<(Digest, u32)> = positions
+            .iter()
+            .map(|&p| (digests[p as usize], p))
+            .collect();
+        assert!(deserialized
+            .verify_multi_inclusion(&mut hasher, &elements, &root)
+            .is_ok());
+    }
+
+    #[test]
+    fn test_multi_proof_serialization_truncated() {
+        // Create test data
+        let digests: Vec<Digest> = (0..8u32).map(|i| Sha256::hash(&i.to_be_bytes())).collect();
+
+        // Build tree
+        let mut builder = Builder::<Sha256>::new(digests.len());
+        for digest in &digests {
+            builder.add(digest);
+        }
+        let tree = builder.build();
+
+        // Generate proof
+        let positions = [0, 3, 5];
+        let multi_proof = tree.multi_proof(&positions).unwrap();
+
+        // Serialize and truncate
+        let mut serialized = multi_proof.encode();
+        serialized.truncate(serialized.len() - 1);
+
+        // Should fail to deserialize
+        assert!(Proof::<Digest>::decode_cfg(&mut serialized, &positions.len()).is_err());
+    }
+
+    #[test]
+    fn test_multi_proof_serialization_extra() {
+        // Create test data
+        let digests: Vec<Digest> = (0..8u32).map(|i| Sha256::hash(&i.to_be_bytes())).collect();
+
+        // Build tree
+        let mut builder = Builder::<Sha256>::new(digests.len());
+        for digest in &digests {
+            builder.add(digest);
+        }
+        let tree = builder.build();
+
+        // Generate proof
+        let positions = [0, 3, 5];
+        let multi_proof = tree.multi_proof(&positions).unwrap();
+
+        // Serialize and add extra byte
+        let mut serialized = multi_proof.encode_mut();
+        serialized.extend_from_slice(&[0u8]);
+
+        // Should fail to deserialize
+        assert!(Proof::<Digest>::decode_cfg(&mut serialized, &positions.len()).is_err());
+    }
+
+    #[test]
+    fn test_multi_proof_decode_insufficient_data() {
+        let mut serialized = Vec::new();
+        serialized.extend_from_slice(&8u32.encode()); // leaf_count
+        serialized.extend_from_slice(&1usize.encode()); // claims 1 sibling but no data follows
+
+        // Should fail because the buffer claims 1 sibling but doesn't have the data
+        let err = Proof::<Digest>::decode_cfg(serialized.as_slice(), &1).unwrap_err();
+        assert!(matches!(err, commonware_codec::Error::EndOfBuffer));
+    }
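+
+    // The decode tests above assume the wire layout is a fixed-width
+    // leaf_count followed by a length-prefixed sibling vector: one missing
+    // byte truncates the final digest, and one extra byte trips the
+    // trailing-data check.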
+
+    #[test]
+    fn test_multi_proof_invalid_position() {
+        // Create test data
+        let digests: Vec<Digest> = (0..8u32).map(|i| Sha256::hash(&i.to_be_bytes())).collect();
+
+        // Build tree
+        let mut builder = Builder::<Sha256>::new(digests.len());
+        for digest in &digests {
+            builder.add(digest);
+        }
+        let tree = builder.build();
+
+        // Test out of bounds position
+        assert!(matches!(
+            tree.multi_proof(&[0, 8]),
+            Err(Error::InvalidPosition(8))
+        ));
+        assert!(matches!(
+            tree.multi_proof(&[100]),
+            Err(Error::InvalidPosition(100))
+        ));
+    }
+
+    #[test]
+    fn test_multi_proof_verify_invalid_position() {
+        // Create test data
+        let digests: Vec<Digest> = (0..8u32).map(|i| Sha256::hash(&i.to_be_bytes())).collect();
+
+        // Build tree
+        let mut builder = Builder::<Sha256>::new(digests.len());
+        for digest in &digests {
+            builder.add(digest);
+        }
+        let tree = builder.build();
+        let root = tree.root();
+        let mut hasher = Sha256::default();
+
+        // Generate valid proof
+        let positions = [0, 3];
+        let multi_proof = tree.multi_proof(&positions).unwrap();
+
+        // Try to verify with out of bounds position
+        let invalid_elements = [(digests[0], 0), (digests[3], 100)];
+        assert!(multi_proof
+            .verify_multi_inclusion(&mut hasher, &invalid_elements, &root)
+            .is_err());
+    }
+
+    #[test]
+    fn test_multi_proof_odd_tree_sizes() {
+        // Test odd-sized trees that require node duplication
+        for tree_size in [3, 5, 7, 9, 11, 13, 15] {
+            let digests: Vec<Digest> = (0..tree_size as u32)
+                .map(|i| Sha256::hash(&i.to_be_bytes()))
+                .collect();
+
+            // Build tree
+            let mut builder = Builder::<Sha256>::new(digests.len());
+            for digest in &digests {
+                builder.add(digest);
+            }
+            let tree = builder.build();
+            let root = tree.root();
+            let mut hasher = Sha256::default();
+
+            // Test with positions including the last element
+            let positions = [0, (tree_size - 1) as u32];
+            let multi_proof = tree.multi_proof(&positions).unwrap();
+
+            let elements: Vec<(Digest, u32)> = positions
+                .iter()
+                .map(|&p| (digests[p as usize], p))
+                .collect();
+            assert!(
+                multi_proof
+                    .verify_multi_inclusion(&mut hasher, &elements, &root)
+                    .is_ok(),
+                "Failed for tree_size={tree_size}"
+            );
+        }
+    }
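+
+    // A complementary sketch of the duplication rule (illustrative, same API
+    // as above): when every leaf of an odd-sized tree is proven at once, the
+    // odd leaf's partner is derived by duplication, so verification needs no
+    // outside material.
+    #[test]
+    fn test_multi_proof_all_leaves_odd_tree() {
+        let digests: Vec<Digest> = (0..3u32).map(|i| Sha256::hash(&i.to_be_bytes())).collect();
+        let mut builder = Builder::<Sha256>::new(digests.len());
+        for digest in &digests {
+            builder.add(digest);
+        }
+        let tree = builder.build();
+        let root = tree.root();
+        let mut hasher = Sha256::default();
+
+        let multi_proof = tree.multi_proof(&[0, 1, 2]).unwrap();
+        let elements: Vec<(Digest, u32)> = (0..3u32).map(|p| (digests[p as usize], p)).collect();
+        assert!(multi_proof
+            .verify_multi_inclusion(&mut hasher, &elements, &root)
+            .is_ok());
+    }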
+
+    #[test]
+    fn test_multi_proof_verify_empty_elements() {
+        // Create a valid proof and try to verify with empty elements
+        let digests: Vec<Digest> = (0..8u32).map(|i| Sha256::hash(&i.to_be_bytes())).collect();
+
+        let mut builder = Builder::<Sha256>::new(digests.len());
+        for digest in &digests {
+            builder.add(digest);
+        }
+        let tree = builder.build();
+        let root = tree.root();
+        let mut hasher = Sha256::default();
+
+        // Generate valid proof
+        let positions = [0, 3];
+        let multi_proof = tree.multi_proof(&positions).unwrap();
+
+        // Try to verify with empty elements
+        let empty_elements: &[(Digest, u32)] = &[];
+        assert!(multi_proof
+            .verify_multi_inclusion(&mut hasher, empty_elements, &root)
+            .is_err());
+    }
+
+    #[test]
+    fn test_multi_proof_default_verify() {
+        // Default (empty) proof should only verify against empty tree
+        let mut hasher = Sha256::default();
+        let default_proof = Proof::<Digest>::default();
+
+        // Empty elements against default proof
+        let empty_elements: &[(Digest, u32)] = &[];
+
+        // Build empty tree to get the empty root
+        let builder = Builder::<Sha256>::new(0);
+        let empty_tree = builder.build();
+        let empty_root = empty_tree.root();
+
+        assert!(default_proof
+            .verify_multi_inclusion(&mut hasher, empty_elements, &empty_root)
+            .is_ok());
+
+        // Should fail with wrong root
+        let wrong_root = Sha256::hash(b"not_empty");
+        assert!(default_proof
+            .verify_multi_inclusion(&mut hasher, empty_elements, &wrong_root)
+            .is_err());
+    }
+
+    #[test]
+    fn test_multi_proof_single_leaf_tree() {
+        // Edge case: tree with exactly one leaf
+        let digest = Sha256::hash(b"only_leaf");
+
+        // Build single-leaf tree
+        let mut builder = Builder::<Sha256>::new(1);
+        builder.add(&digest);
+        let tree = builder.build();
+        let root = tree.root();
+        let mut hasher = Sha256::default();
+
+        // Generate multi-proof for the only leaf
+        let multi_proof = tree.multi_proof(&[0]).unwrap();
+
+        // Single leaf tree: leaf_count should be 1
+        assert_eq!(multi_proof.leaf_count, 1);
+
+        // Single leaf tree: no siblings needed (leaf is the root after position hashing)
+        assert!(
+            multi_proof.siblings.is_empty(),
+            "Single leaf tree should have no siblings"
+        );
+
+        // Verify the proof
+        let elements = [(digest, 0u32)];
+        assert!(
+            multi_proof
+                .verify_multi_inclusion(&mut hasher, &elements, &root)
+                .is_ok(),
+            "Single leaf multi-proof verification failed"
+        );
+
+        // Verify with wrong digest fails
+        let wrong_digest = Sha256::hash(b"wrong");
+        let wrong_elements = [(wrong_digest, 0u32)];
+        assert!(
+            multi_proof
+                .verify_multi_inclusion(&mut hasher, &wrong_elements, &root)
+                .is_err(),
+            "Should fail with wrong digest"
+        );
+
+        // Verify with wrong position fails
+        let wrong_position_elements = [(digest, 1u32)];
+        assert!(
+            multi_proof
+                .verify_multi_inclusion(&mut hasher, &wrong_position_elements, &root)
+                .is_err(),
+            "Should fail with invalid position"
+        );
+    }
+
+    #[test]
+    fn test_multi_proof_malicious_leaf_count_zero() {
+        // Attacker sets leaf_count = 0 but provides siblings
+        let digests: Vec<Digest> = (0..8u32).map(|i| Sha256::hash(&i.to_be_bytes())).collect();
+
+        let mut builder = Builder::<Sha256>::new(digests.len());
+        for digest in &digests {
+            builder.add(digest);
+        }
+        let tree = builder.build();
+        let root = tree.root();
+        let mut hasher = Sha256::default();
+
+        // Generate valid proof and tamper with leaf_count
+        let positions = [0, 3];
+        let mut multi_proof = tree.multi_proof(&positions).unwrap();
+        multi_proof.leaf_count = 0;
+
+        let elements: Vec<(Digest, u32)> = positions
+            .iter()
+            .map(|&p| (digests[p as usize], p))
+            .collect();
+
+        // Should fail - leaf_count=0 but we have elements
+        assert!(multi_proof
+            .verify_multi_inclusion(&mut hasher, &elements, &root)
+            .is_err());
+    }
+
+    #[test]
+    fn test_multi_proof_malicious_leaf_count_larger() {
+        // Attacker inflates leaf_count to claim proof is for larger tree
+        let digests: Vec<Digest> = (0..8u32).map(|i| Sha256::hash(&i.to_be_bytes())).collect();
+
+        let mut builder = Builder::<Sha256>::new(digests.len());
+        for digest in &digests {
+            builder.add(digest);
+        }
+        let tree = builder.build();
+        let root = tree.root();
+        let mut hasher = Sha256::default();
+
+        // Generate valid proof and inflate leaf_count
+        let positions = [0, 3];
+        let mut multi_proof = tree.multi_proof(&positions).unwrap();
+        let original_leaf_count = multi_proof.leaf_count;
+        multi_proof.leaf_count = 1000;
+
+        let elements: Vec<(Digest, u32)> = positions
+            .iter()
+            .map(|&p| (digests[p as usize], p))
+            .collect();
+
+        // Should fail - inflated leaf_count changes required siblings
+        assert!(
+            multi_proof
+                .verify_multi_inclusion(&mut hasher, &elements, &root)
+                .is_err(),
+            "Should reject proof with inflated leaf_count ({} -> {})",
+            original_leaf_count,
+            multi_proof.leaf_count
+        );
+    }
+
+    #[test]
+    fn test_multi_proof_malicious_leaf_count_smaller() {
+        // Attacker deflates leaf_count to claim proof is for smaller tree
+        let digests: Vec<Digest> = (0..8u32).map(|i| Sha256::hash(&i.to_be_bytes())).collect();
+
+        let mut builder = Builder::<Sha256>::new(digests.len());
+        for digest in &digests {
+            builder.add(digest);
+        }
+        let tree = builder.build();
+        let root = tree.root();
+        let mut hasher = Sha256::default();
+
+        // Generate valid proof and deflate leaf_count
+        let positions = [0, 3];
+        let mut multi_proof = tree.multi_proof(&positions).unwrap();
+        multi_proof.leaf_count = 4; // Smaller than actual tree
+
+        let elements: Vec<(Digest, u32)> = positions
+            .iter()
+            .map(|&p| (digests[p as usize], p))
+            .collect();
+
+        // Should fail - deflated leaf_count changes tree structure
+        assert!(
+            multi_proof
+                .verify_multi_inclusion(&mut hasher, &elements, &root)
+                .is_err(),
+            "Should reject proof with deflated leaf_count"
+        );
+    }
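+
+    // Why the three leaf_count forgeries above must fail: leaf_count fixes
+    // the tree shape, and with it how many sibling levels verification
+    // consumes, so a forged value either starves or overruns the sibling
+    // list, or reconstructs a different root.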
+
+    #[test]
+    fn test_multi_proof_mismatched_element_count() {
+        // Provide more or fewer elements than the proof was generated for
+        let digests: Vec<Digest> = (0..8u32).map(|i| Sha256::hash(&i.to_be_bytes())).collect();
+
+        let mut builder = Builder::<Sha256>::new(digests.len());
+        for digest in &digests {
+            builder.add(digest);
+        }
+        let tree = builder.build();
+        let root = tree.root();
+        let mut hasher = Sha256::default();
+
+        // Generate proof for 2 positions
+        let positions = [0, 3];
+        let multi_proof = tree.multi_proof(&positions).unwrap();
+
+        // Try to verify with only 1 element (too few)
+        let too_few = [(digests[0], 0u32)];
+        assert!(
+            multi_proof
+                .verify_multi_inclusion(&mut hasher, &too_few, &root)
+                .is_err(),
+            "Should reject when fewer elements provided than proof was generated for"
+        );
+
+        // Try to verify with 3 elements (too many)
+        let too_many = [(digests[0], 0u32), (digests[3], 3), (digests[5], 5)];
+        assert!(
+            multi_proof
+                .verify_multi_inclusion(&mut hasher, &too_many, &root)
+                .is_err(),
+            "Should reject when more elements provided than proof was generated for"
+        );
+    }
+
+    #[test]
+    fn test_multi_proof_swapped_siblings() {
+        // Swap the order of siblings in the proof
+        let digests: Vec<Digest> = (0..8u32).map(|i| Sha256::hash(&i.to_be_bytes())).collect();
+
+        let mut builder = Builder::<Sha256>::new(digests.len());
+        for digest in &digests {
+            builder.add(digest);
+        }
+        let tree = builder.build();
+        let root = tree.root();
+        let mut hasher = Sha256::default();
+
+        // Generate valid proof with multiple siblings
+        let positions = [0, 5];
+        let mut multi_proof = tree.multi_proof(&positions).unwrap();
+
+        // Ensure we have at least 2 siblings to swap
+        if multi_proof.siblings.len() >= 2 {
+            // Swap first two siblings
+            multi_proof.siblings.swap(0, 1);
+
+            let elements: Vec<(Digest, u32)> = positions
+                .iter()
+                .map(|&p| (digests[p as usize], p))
+                .collect();
+
+            assert!(
+                multi_proof
+                    .verify_multi_inclusion(&mut hasher, &elements, &root)
+                    .is_err(),
+                "Should reject proof with swapped siblings"
+            );
+        }
+    }
+
+    #[test]
+    fn test_multi_proof_dos_large_leaf_count() {
+        // Attacker sets massive leaf_count trying to cause DoS via memory allocation
+        // The verify function should NOT allocate proportional to leaf_count
+        let digests: Vec<Digest> = (0..4u32).map(|i| Sha256::hash(&i.to_be_bytes())).collect();
+
+        let mut builder = Builder::<Sha256>::new(digests.len());
+        for digest in &digests {
+            builder.add(digest);
+        }
+        let tree = builder.build();
+        let root = tree.root();
+        let mut hasher = Sha256::default();
+
+        // Generate valid proof
+        let positions = [0, 2];
+        let mut multi_proof = tree.multi_proof(&positions).unwrap();
+
+        // Set massive leaf_count (attacker trying to exhaust memory)
+        multi_proof.leaf_count = u32::MAX;
+
+        let elements: Vec<(Digest, u32)> = positions
+            .iter()
+            .map(|&p| (digests[p as usize], p))
+            .collect();
+
+        // This should fail quickly without allocating massive memory
+        // The function is O(elements * levels), not O(leaf_count)
+        let result = multi_proof.verify_multi_inclusion(&mut hasher, &elements, &root);
+        assert!(result.is_err(), "Should reject malicious large leaf_count");
+    }
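+
+    // The guarantee exercised here: verification work is O(elements * levels),
+    // and the level count is only the logarithm of the claimed leaf_count (at
+    // most 32 for a u32), so even a forged u32::MAX leaf_count cannot force a
+    // large allocation before the proof is rejected.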
 
     #[cfg(feature = "arbitrary")]
     mod conformance {
         use super::*;
 
@@ -1719,9 +2870,9 @@ mod tests {
        use commonware_cryptography::sha256::Digest as Sha256Digest;
 
        commonware_conformance::conformance_tests! {
-            CodecConformance>,
-            CodecConformance>,
+            CodecConformance>,
             CodecConformance>,
+            CodecConformance>,
         }
     }
 }
diff --git a/storage/src/mmr/mod.rs b/storage/src/mmr/mod.rs
index fe820a688c..e0ed2779c6 100644
--- a/storage/src/mmr/mod.rs
+++ b/storage/src/mmr/mod.rs
@@ -87,7 +87,7 @@ cfg_if::cfg_if! {
 pub use hasher::Standard as StandardHasher;
 pub use location::{Location, LocationError, MAX_LOCATION};
 pub use position::{Position, MAX_POSITION};
-pub use proof::Proof;
+pub use proof::{Proof, MAX_PROOF_DIGESTS_PER_ELEMENT};
 use thiserror::Error;
 
 /// Errors that can occur when interacting with an MMR.
diff --git a/storage/src/mmr/proof.rs b/storage/src/mmr/proof.rs
index 0481c6e112..bf41c26652 100644
--- a/storage/src/mmr/proof.rs
+++ b/storage/src/mmr/proof.rs
@@ -23,6 +23,13 @@ use core::ops::Range;
 #[cfg(feature = "std")]
 use tracing::debug;
 
+/// The maximum number of digests in a proof per element being proven.
+///
+/// This accounts for the worst-case proof size: in an MMR with 62 peaks, the
+/// left-most leaf requires 122 digests (61 path siblings plus 61 digests for
+/// the other peaks).
+pub const MAX_PROOF_DIGESTS_PER_ELEMENT: usize = 122;
+
 /// Errors that can occur when reconstructing a digest from a proof due to invalid input.
 #[derive(Error, Debug)]
 pub enum ReconstructionError {
@@ -85,16 +92,21 @@ impl<D: Digest> Write for Proof<D> {
 }
 
 impl<D: Digest> Read for Proof<D> {
-    /// The maximum number of digests in the proof.
+    /// The maximum number of items being proven.
+    ///
+    /// The upper bound on digests is derived as `max_items * MAX_PROOF_DIGESTS_PER_ELEMENT`.
     type Cfg = usize;
 
-    fn read_cfg(buf: &mut impl Buf, max_len: &Self::Cfg) -> Result<Self, CodecError> {
+    fn read_cfg(
+        buf: &mut impl Buf,
+        max_items: &Self::Cfg,
+    ) -> Result<Self, CodecError> {
         // Read the number of nodes in the MMR
         let size = Position::new(UInt::<u64>::read(buf)?.into());
 
         // Read the digests
-        let range = ..=max_len;
-        let digests = Vec::<D>::read_range(buf, range)?;
+        let max_digests = max_items.saturating_mul(MAX_PROOF_DIGESTS_PER_ELEMENT);
+        let digests = Vec::<D>::read_range(buf, ..=max_digests)?;
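+        // For example, with max_items = 3 the decoder accepts at most
+        // 3 * 122 = 366 digests, however many the buffer claims to hold.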
         Ok(Self { size, digests })
     }
 }
@@ -1089,8 +1101,9 @@ mod tests {
             expected_size,
             "serialized proof should have expected size"
         );
-        let max_digests = proof.digests.len();
-        let deserialized_proof = Proof::decode_cfg(serialized_proof, &max_digests).unwrap();
+        // max_items is the number of elements in the range
+        let max_items = j - i;
+        let deserialized_proof = Proof::decode_cfg(serialized_proof, &max_items).unwrap();
         assert_eq!(
             proof, deserialized_proof,
             "deserialized proof should match source proof"
@@ -1101,7 +1114,7 @@ mod tests {
         let serialized_proof = proof.encode();
         let serialized_proof: Bytes = serialized_proof.slice(0..serialized_proof.len() - 1);
         assert!(
-            Proof::<Digest>::decode_cfg(serialized_proof, &max_digests).is_err(),
+            Proof::<Digest>::decode_cfg(serialized_proof, &max_items).is_err(),
             "proof should not deserialize with truncated data"
        );
 
         let serialized_proof = serialized_proof;
         assert!(
-            Proof::<Digest>::decode_cfg(serialized_proof, &max_digests).is_err(),
+            Proof::<Digest>::decode_cfg(serialized_proof, &max_items).is_err(),
             "proof should not deserialize with extra data"
         );
 
-        // Confirm deserialization fails when max length is exceeded.
-        if max_digests > 0 {
+        // Confirm deserialization fails when max_items is too small.
+        let actual_digests = proof.digests.len();
+        if actual_digests > 0 {
+            // Find the minimum max_items that would allow this many digests
+            let min_max_items = actual_digests.div_ceil(MAX_PROOF_DIGESTS_PER_ELEMENT);
+            // Using one less should fail
+            let too_small = min_max_items - 1;
             let serialized_proof = proof.encode();
             assert!(
-                Proof::<Digest>::decode_cfg(serialized_proof, &(max_digests - 1)).is_err(),
-                "proof should not deserialize with max length exceeded"
+                Proof::<Digest>::decode_cfg(serialized_proof, &too_small).is_err(),
+                "proof should not deserialize with max_items too small"
             );
         }
     }
@@ -1626,6 +1644,109 @@ mod tests {
         }
     }
 
+    #[test]
+    fn test_max_proof_digests_per_element_sufficient() {
+        // Verify that MAX_PROOF_DIGESTS_PER_ELEMENT (122) is sufficient for any single-element
+        // proof in the largest valid MMR.
+        //
+        // MMR sizes follow: mmr_size(N) = 2*N - popcount(N) where N = leaf count.
+        // The number of peaks equals popcount(N).
+        //
+        // To maximize peaks, we want N with maximum popcount. N = 2^62 - 1 has 62 one-bits:
+        //   N = 0x3FFFFFFFFFFFFFFF = 2^0 + 2^1 + ... + 2^61
+        //
+        // This gives us 62 perfect binary trees with leaf counts 2^0, 2^1, ..., 2^61
+        // and corresponding heights 0, 1, ..., 61.
+        //
+        //   mmr_size(2^62 - 1) = 2*(2^62 - 1) - 62 = 2^63 - 2 - 62 = 2^63 - 64
+        //
+        // For a single-element proof in a tree of height h:
+        // - Path siblings from leaf to peak: h digests
+        // - Other peaks (not containing the element): (62 - 1) = 61 digests
+        // - Total: h + 61 digests
+        //
+        // Worst case: element in tallest tree (h = 61)
+        // - Path siblings: 61
+        // - Other peaks: 61
+        // - Total: 61 + 61 = 122 digests
+
+        const NUM_PEAKS: usize = 62;
+        const MAX_TREE_HEIGHT: usize = 61;
+        const EXPECTED_WORST_CASE: usize = MAX_TREE_HEIGHT + (NUM_PEAKS - 1);
+
+        let many_peaks_size = Position::new((1u64 << 63) - 64);
+        assert!(
+            many_peaks_size.is_mmr_size(),
+            "Size {many_peaks_size} should be a valid MMR size",
+        );
+
+        let peak_count = PeakIterator::new(many_peaks_size).count();
+        assert_eq!(peak_count, NUM_PEAKS);
+
+        // Verify the peak heights are 61, 60, ..., 1, 0 (from left to right)
+        let peaks: Vec<_> = PeakIterator::new(many_peaks_size).collect();
+        for (i, &(_pos, height)) in peaks.iter().enumerate() {
+            let expected_height = (NUM_PEAKS - 1 - i) as u32;
+            assert_eq!(
+                height, expected_height,
+                "Peak {i} should have height {expected_height}, got {height}",
+            );
+        }
+
+        // Test location 0 (leftmost leaf, in tallest tree of height 61)
+        // Expected: 61 path siblings + 61 other peaks = 122 digests
+        let loc = Location::new_unchecked(0);
+        let positions = nodes_required_for_range_proof(many_peaks_size, loc..loc + 1)
+            .expect("should compute positions for location 0");
+
+        assert_eq!(
+            positions.len(),
+            EXPECTED_WORST_CASE,
+            "Location 0 proof should require exactly {EXPECTED_WORST_CASE} digests (61 path + 61 peaks)",
+        );
+
+        // Test the rightmost leaf (in smallest tree of height 0, which is itself a peak)
+        // Expected: 0 path siblings + 61 other peaks = 61 digests
+        let last_leaf_loc = (1u64 << 62) - 2; // Last leaf location
+        let positions = nodes_required_for_range_proof(
+            many_peaks_size,
+            Location::new_unchecked(last_leaf_loc)..Location::new_unchecked(last_leaf_loc + 1),
+        )
+        .expect("should compute positions for last leaf");
+
+        let expected_last_leaf = NUM_PEAKS - 1;
+        assert_eq!(
+            positions.len(),
+            expected_last_leaf,
+            "Last leaf proof should require exactly {expected_last_leaf} digests (0 path + 61 peaks)",
+        );
+    }
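+
+    // Sanity sketch of the size formula cited above (cases are illustrative):
+    // mmr_size(N) = 2*N - popcount(N), e.g. N = 3 -> 4 and N = 7 -> 11.
+    #[test]
+    fn test_mmr_size_formula_small_cases() {
+        for (n, expected) in [(1u64, 1u64), (2, 3), (3, 4), (7, 11)] {
+            assert_eq!(2 * n - u64::from(n.count_ones()), expected);
+            assert!(Position::new(expected).is_mmr_size());
+        }
+    }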
+
+    #[test]
+    fn test_max_proof_digests_per_element_is_maximum() {
+        // For K peaks, the worst-case proof needs: (max_tree_height) + (K - 1) digests.
+        // With K peaks of heights K-1, K-2, ..., 0, this is (K-1) + (K-1) = 2*(K-1).
+        //
+        // To get K peaks, leaf count N must have exactly K bits set.
+        // MMR size = 2*N - popcount(N) = 2*N - K
+        //
+        // For 63 peaks: N = 2^63 - 1 (63 bits set), size = 2*(2^63 - 1) - 63 = 2^64 - 65.
+        // This exceeds MAX_POSITION, so is_mmr_size() returns false.
+
+        let n_for_63_peaks = (1u128 << 63) - 1;
+        let size_for_63_peaks = 2 * n_for_63_peaks - 63; // = 2^64 - 65
+        assert!(
+            size_for_63_peaks > *crate::mmr::MAX_POSITION as u128,
+            "63 peaks requires size {size_for_63_peaks} > MAX_POSITION",
+        );
+
+        let size_truncated = size_for_63_peaks as u64;
+        assert!(
+            !Position::new(size_truncated).is_mmr_size(),
+            "Size for 63 peaks should fail is_mmr_size()"
+        );
+    }
+
     #[cfg(feature = "arbitrary")]
     mod conformance {
         use super::*;