diff --git a/Cargo.lock b/Cargo.lock index 14e6d3b692..4a136964b6 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1274,6 +1274,8 @@ dependencies = [ "ecdsa 0.16.9", "ed25519-consensus", "getrandom 0.2.16", + "num-rational", + "num-traits", "p256 0.13.2", "proptest", "rand 0.8.5", @@ -1301,6 +1303,7 @@ dependencies = [ "commonware-utils", "ed25519-zebra", "libfuzzer-sys", + "num-rational", "p256 0.13.2", "rand 0.8.5", "sha2 0.10.9", diff --git a/cryptography/Cargo.toml b/cryptography/Cargo.toml index a6ed333390..066b4db74b 100644 --- a/cryptography/Cargo.toml +++ b/cryptography/Cargo.toml @@ -29,6 +29,8 @@ crc-fast = { workspace = true, features = ["panic-handler"] } ctutils.workspace = true ecdsa.workspace = true ed25519-consensus = { workspace = true, default-features = false } +num-rational = { workspace = true, optional = true } +num-traits = { workspace = true, optional = true } p256 = { workspace = true, features = ["ecdsa"] } rand.workspace = true rand_chacha.workspace = true @@ -87,6 +89,9 @@ std = [ "ecdsa/std", "ed25519-consensus/std", "getrandom/std", + "num-rational", + "num-traits", + "num-traits?/std", "p256/std", "rand/std", "rand/std_rng", @@ -131,3 +136,8 @@ path = "src/lthash/benches/bench.rs" name = "handshake" harness = false path = "src/handshake/benches/bench.rs" + +[[bench]] +name = "bloomfilter" +harness = false +path = "src/bloomfilter/benches/bench.rs" diff --git a/cryptography/conformance.toml b/cryptography/conformance.toml index 5872b0487b..dfb802e933 100644 --- a/cryptography/conformance.toml +++ b/cryptography/conformance.toml @@ -2,9 +2,13 @@ n_cases = 65536 hash = "d01e3ef0dd81919abd3f149649940ae2e8cff6dc4aa9f927d544aece387b12d1" -["commonware_cryptography::bloomfilter::tests::conformance::CodecConformance"] +["commonware_cryptography::bloomfilter::conformance::CodecConformance"] n_cases = 65536 -hash = "a75d6312366816126114abdc9d7fbd246891cab210eeae6a02781625aa1ad6a4" +hash = "064cd9eab1a79b2270de039d6105df261954aa1fc8df1fbd55ecf6f18f43f437" + +["commonware_cryptography::bloomfilter::conformance::RationalOptimalBits"] +n_cases = 1024 +hash = "f631238a016fb44fd2dee9eeb5bd20fc5052f33434810599b8ba6535e2ad8e97" ["commonware_cryptography::bls12381::certificate::multisig::tests::conformance::CodecConformance>"] n_cases = 65536 diff --git a/cryptography/fuzz/Cargo.toml b/cryptography/fuzz/Cargo.toml index d95e5ec351..f5f594f2b7 100644 --- a/cryptography/fuzz/Cargo.toml +++ b/cryptography/fuzz/Cargo.toml @@ -16,9 +16,10 @@ commonware-codec = { workspace = true, features = ["std"] } commonware-cryptography = { workspace = true, features = ["std", "arbitrary"] } commonware-math.workspace = true commonware-parallel = { workspace = true, features = ["std"] } -commonware-utils = { workspace = true, features = ["std", "arbitrary"] } +commonware-utils = { workspace = true, features = ["std"] } ed25519-zebra.workspace = true libfuzzer-sys.workspace = true +num-rational.workspace = true p256 = { workspace = true, features = ["ecdsa"] } rand.workspace = true sha2.workspace = true diff --git a/cryptography/fuzz/fuzz_targets/bloomfilter.rs b/cryptography/fuzz/fuzz_targets/bloomfilter.rs index 827ef51f18..28403c34a2 100644 --- a/cryptography/fuzz/fuzz_targets/bloomfilter.rs +++ b/cryptography/fuzz/fuzz_targets/bloomfilter.rs @@ -2,11 +2,13 @@ use arbitrary::Arbitrary; use commonware_codec::{Decode, Encode, EncodeSize}; -use commonware_cryptography::BloomFilter; +use commonware_cryptography::{sha256::Sha256, BloomFilter}; +use commonware_utils::rational::BigRationalExt; use libfuzzer_sys::fuzz_target; +use num_rational::BigRational; use std::{ collections::HashSet, - num::{NonZeroU16, NonZeroU8}, + num::{NonZeroU16, NonZeroU8, NonZeroUsize}, }; #[derive(Arbitrary, Debug)] @@ -18,30 +20,80 @@ enum Op { EncodeSize, } +#[derive(Debug)] +enum Constructor { + New { + hashers: NonZeroU8, + bits: NonZeroU16, + }, + WithRate { + expected_items: NonZeroU16, + fp_numerator: u64, + fp_denominator: u64, + }, +} + +impl<'a> Arbitrary<'a> for Constructor { + fn arbitrary(u: &mut arbitrary::Unstructured<'a>) -> arbitrary::Result { + if u.arbitrary::()? { + let hashers = u.arbitrary()?; + // Fallback to highest power of two in u16 on overflow + let bits = u + .arbitrary::()? + .checked_next_power_of_two() + .and_then(NonZeroU16::new) + .unwrap_or(NonZeroU16::new(1 << 15).unwrap()); + Ok(Constructor::New { hashers, bits }) + } else { + let expected_items = u.arbitrary::()?; + // Generate FP rate as rational: numerator in [1, denominator-1] to ensure (0, 1) + let fp_denominator = u.int_in_range(2u64..=10_000)?; + let fp_numerator = u.int_in_range(1u64..=fp_denominator - 1)?; + Ok(Constructor::WithRate { + expected_items, + fp_numerator, + fp_denominator, + }) + } + } +} + const MAX_OPERATIONS: usize = 64; #[derive(Debug)] struct FuzzInput { - hashers: NonZeroU8, - bits: NonZeroU16, + constructor: Constructor, ops: Vec, } impl<'a> Arbitrary<'a> for FuzzInput { fn arbitrary(u: &mut arbitrary::Unstructured<'a>) -> arbitrary::Result { - let hashers = u.arbitrary()?; - let bits = u.arbitrary()?; + let constructor = u.arbitrary()?; let num_ops = u.int_in_range(1..=MAX_OPERATIONS)?; let ops = (0..num_ops) .map(|_| Op::arbitrary(u)) .collect::, _>>()?; - Ok(FuzzInput { hashers, bits, ops }) + Ok(FuzzInput { constructor, ops }) } } fn fuzz(input: FuzzInput) { - let cfg = (input.hashers, input.bits.into()); - let mut bf = BloomFilter::new(input.hashers, input.bits.into()); + let mut bf = match input.constructor { + Constructor::New { hashers, bits } => BloomFilter::::new(hashers, bits.into()), + Constructor::WithRate { + expected_items, + fp_numerator, + fp_denominator, + } => { + let fp_rate = BigRational::from_frac_u64(fp_numerator, fp_denominator); + BloomFilter::::with_rate( + NonZeroUsize::new(expected_items.get() as usize).unwrap(), + fp_rate, + ) + } + }; + + let cfg = (bf.hashers(), bf.bits().try_into().unwrap()); let mut model: HashSet> = HashSet::new(); for op in input.ops { @@ -58,11 +110,11 @@ fn fuzz(input: FuzzInput) { } Op::DecodeCfg(data, hashers, bits) => { let cfg = (hashers, bits.into()); - _ = BloomFilter::decode_cfg(&data[..], &cfg); + _ = BloomFilter::::decode_cfg(&data[..], &cfg); } Op::Encode(_item) => { let encoded = bf.encode(); - let decoded = BloomFilter::decode_cfg(encoded.clone(), &cfg).unwrap(); + let decoded = BloomFilter::::decode_cfg(encoded.clone(), &cfg).unwrap(); assert_eq!(bf, decoded); let encode_size = bf.encode_size(); @@ -80,7 +132,7 @@ fn fuzz(input: FuzzInput) { "encode_size should match encode().len()" ); - let decoded = BloomFilter::decode_cfg(encoded, &cfg).unwrap(); + let decoded = BloomFilter::::decode_cfg(encoded, &cfg).unwrap(); assert_eq!(bf, decoded); assert_eq!(decoded.encode_size(), size1); diff --git a/cryptography/src/bloomfilter.rs b/cryptography/src/bloomfilter.rs deleted file mode 100644 index 5859b18558..0000000000 --- a/cryptography/src/bloomfilter.rs +++ /dev/null @@ -1,274 +0,0 @@ -//! An implementation of a [Bloom Filter](https://en.wikipedia.org/wiki/Bloom_filter). - -use crate::{ - sha256::{Digest, Sha256}, - Hasher, -}; -use bytes::{Buf, BufMut}; -use commonware_codec::{ - codec::{Read, Write}, - error::Error as CodecError, - EncodeSize, FixedSize, -}; -use commonware_utils::bitmap::BitMap; -use core::num::{NonZeroU64, NonZeroU8, NonZeroUsize}; - -/// The length of a half of a [Digest]. -const HALF_DIGEST_LEN: usize = 16; - -/// The length of a full [Digest]. -const FULL_DIGEST_LEN: usize = Digest::SIZE; - -/// A [Bloom Filter](https://en.wikipedia.org/wiki/Bloom_filter). -/// -/// This implementation uses the Kirsch-Mitzenmacher optimization to derive `k` hash functions -/// from two hash values, which are in turn derived from a single [Digest]. This provides -/// efficient hashing for [BloomFilter::insert] and [BloomFilter::contains] operations. -#[derive(Clone, Debug, PartialEq, Eq)] -pub struct BloomFilter { - hashers: u8, - bits: BitMap, -} - -impl BloomFilter { - /// Creates a new [BloomFilter] with `hashers` hash functions and `bits` bits. - pub fn new(hashers: NonZeroU8, bits: NonZeroUsize) -> Self { - Self { - hashers: hashers.get(), - bits: BitMap::zeroes(bits.get() as u64), - } - } - - /// Generate `num_hashers` bit indices for a given item. - fn indices(&self, item: &[u8], bits: u64) -> impl Iterator { - // Extract two 128-bit hash values from the SHA256 digest of the item - let digest = Sha256::hash(item); - let mut h1_bytes = [0u8; HALF_DIGEST_LEN]; - h1_bytes.copy_from_slice(&digest[0..HALF_DIGEST_LEN]); - let h1 = u128::from_be_bytes(h1_bytes); - let mut h2_bytes = [0u8; HALF_DIGEST_LEN]; - h2_bytes.copy_from_slice(&digest[HALF_DIGEST_LEN..FULL_DIGEST_LEN]); - let h2 = u128::from_be_bytes(h2_bytes); - - // Generate `hashers` hashes using the Kirsch-Mitzenmacher optimization: - // - // `h_i(x) = (h1(x) + i * h2(x)) mod m` - let hashers = self.hashers as u128; - let bits = bits as u128; - (0..hashers) - .map(move |hasher| h1.wrapping_add(hasher.wrapping_mul(h2)) % bits) - .map(|index| index as u64) - } - - /// Inserts an item into the [BloomFilter]. - pub fn insert(&mut self, item: &[u8]) { - let indices = self.indices(item, self.bits.len()); - for index in indices { - self.bits.set(index, true); - } - } - - /// Checks if an item is possibly in the [BloomFilter]. - /// - /// Returns `true` if the item is probably in the set, and `false` if it is definitely not. - pub fn contains(&self, item: &[u8]) -> bool { - let indices = self.indices(item, self.bits.len()); - for index in indices { - if !self.bits.get(index) { - return false; - } - } - true - } -} - -impl Write for BloomFilter { - fn write(&self, buf: &mut impl BufMut) { - self.hashers.write(buf); - self.bits.write(buf); - } -} - -impl Read for BloomFilter { - // The number of hashers and the number of bits that the bitmap must have. - type Cfg = (NonZeroU8, NonZeroU64); - - fn read_cfg( - buf: &mut impl Buf, - (hashers_cfg, bits_cfg): &Self::Cfg, - ) -> Result { - let hashers = u8::read_cfg(buf, &())?; - if hashers != hashers_cfg.get() { - return Err(CodecError::Invalid( - "BloomFilter", - "hashers doesn't match config", - )); - } - let bits = BitMap::read_cfg(buf, &bits_cfg.get())?; - if bits.len() != bits_cfg.get() { - return Err(CodecError::Invalid( - "BloomFilter", - "bitmap length doesn't match config", - )); - } - Ok(Self { hashers, bits }) - } -} - -impl EncodeSize for BloomFilter { - fn encode_size(&self) -> usize { - self.hashers.encode_size() + self.bits.encode_size() - } -} - -#[cfg(feature = "arbitrary")] -impl arbitrary::Arbitrary<'_> for BloomFilter { - fn arbitrary(u: &mut arbitrary::Unstructured<'_>) -> arbitrary::Result { - let hashers = u8::arbitrary(u)?; - // Ensure at least 1 bit to avoid empty bitmap - let bits_len = u.arbitrary_len::()?.max(1); - let mut bits = BitMap::with_capacity(bits_len as u64); - for _ in 0..bits_len { - bits.push(u.arbitrary::()?); - } - Ok(Self { hashers, bits }) - } -} - -#[cfg(test)] -mod tests { - use super::*; - use commonware_codec::{Decode, Encode}; - use commonware_utils::{NZUsize, NZU64, NZU8}; - - #[test] - fn test_insert_and_contains() { - let mut bf = BloomFilter::new(NZU8!(10), NZUsize!(1000)); - let item1 = b"hello"; - let item2 = b"world"; - let item3 = b"bloomfilter"; - - bf.insert(item1); - bf.insert(item2); - - assert!(bf.contains(item1)); - assert!(bf.contains(item2)); - assert!(!bf.contains(item3)); - } - - #[test] - fn test_empty() { - let bf = BloomFilter::new(NZU8!(5), NZUsize!(100)); - assert!(!bf.contains(b"anything")); - } - - #[test] - fn test_false_positives() { - let mut bf = BloomFilter::new(NZU8!(10), NZUsize!(100)); - for i in 0..10usize { - bf.insert(&i.to_be_bytes()); - } - - // Check for inserted items - for i in 0..10usize { - assert!(bf.contains(&i.to_be_bytes())); - } - - // Check for non-inserted items and count false positives - let mut false_positives = 0; - for i in 100..1100usize { - if bf.contains(&i.to_be_bytes()) { - false_positives += 1; - } - } - - // A small bloom filter with many items will have some false positives. - // The exact number is probabilistic, but it should not be zero and not all should be FPs. - assert!(false_positives > 0); - assert!(false_positives < 1000); - } - - #[test] - fn test_codec_roundtrip() { - let mut bf = BloomFilter::new(NZU8!(5), NZUsize!(100)); - bf.insert(b"test1"); - bf.insert(b"test2"); - - let cfg = (NZU8!(5), NZU64!(100)); - - let encoded = bf.encode(); - let decoded = BloomFilter::decode_cfg(encoded, &cfg).unwrap(); - - assert_eq!(bf, decoded); - } - - #[test] - fn test_codec_empty() { - let bf = BloomFilter::new(NZU8!(4), NZUsize!(128)); - let cfg = (NZU8!(4), NZU64!(128)); - let encoded = bf.encode(); - let decoded = BloomFilter::decode_cfg(encoded, &cfg).unwrap(); - assert_eq!(bf, decoded); - } - - #[test] - fn test_codec_with_invalid_hashers() { - let mut bf = BloomFilter::new(NZU8!(5), NZUsize!(100)); - bf.insert(b"test1"); - let encoded = bf.encode(); - - // Too large - let cfg = (NZU8!(10), NZU64!(100)); - let decoded = BloomFilter::decode_cfg(encoded.clone(), &cfg); - assert!(matches!( - decoded, - Err(CodecError::Invalid( - "BloomFilter", - "hashers doesn't match config" - )) - )); - - // Too small - let cfg = (NZU8!(4), NZU64!(100)); - let decoded = BloomFilter::decode_cfg(encoded, &cfg); - assert!(matches!( - decoded, - Err(CodecError::Invalid( - "BloomFilter", - "hashers doesn't match config" - )) - )); - } - - #[test] - fn test_codec_with_invalid_bits() { - let mut bf = BloomFilter::new(NZU8!(5), NZUsize!(100)); - bf.insert(b"test1"); - let encoded = bf.encode(); - - // Wrong bit count - let cfg = (NZU8!(5), NZU64!(99)); - let result = BloomFilter::decode_cfg(encoded.clone(), &cfg); - assert!(matches!(result, Err(CodecError::InvalidLength(100)))); - - let cfg = (NZU8!(5), NZU64!(101)); - let result = BloomFilter::decode_cfg(encoded, &cfg); - assert!(matches!( - result, - Err(CodecError::Invalid( - "BloomFilter", - "bitmap length doesn't match config" - )) - )); - } - - #[cfg(feature = "arbitrary")] - mod conformance { - use super::*; - use commonware_codec::conformance::CodecConformance; - - commonware_conformance::conformance_tests! { - CodecConformance, - } - } -} diff --git a/cryptography/src/bloomfilter/benches/bench.rs b/cryptography/src/bloomfilter/benches/bench.rs new file mode 100644 index 0000000000..3540f3f76f --- /dev/null +++ b/cryptography/src/bloomfilter/benches/bench.rs @@ -0,0 +1,6 @@ +use criterion::criterion_main; + +mod contains; +mod insert; + +criterion_main!(insert::benches, contains::benches); diff --git a/cryptography/src/bloomfilter/benches/contains.rs b/cryptography/src/bloomfilter/benches/contains.rs new file mode 100644 index 0000000000..2ab80e0e6c --- /dev/null +++ b/cryptography/src/bloomfilter/benches/contains.rs @@ -0,0 +1,86 @@ +use commonware_cryptography::{blake3::Blake3, sha256::Sha256, BloomFilter, Hasher}; +use commonware_utils::rational::BigRationalExt; +use criterion::{criterion_group, Criterion}; +use num_rational::BigRational; +use rand::{rngs::StdRng, RngCore, SeedableRng}; +use std::{collections::HashSet, hint::black_box, num::NonZeroUsize}; + +const ITEM_SIZES: [usize; 3] = [32, 2048, 4096]; +const NUM_ITEMS: usize = 10000; + +fn fp_rates() -> [(BigRational, &'static str); 2] { + [ + (BigRational::from_frac_u64(1, 10), "10%"), + (BigRational::from_frac_u64(1, 1000), "0.1%"), + ] +} + +fn run_contains_bench(c: &mut Criterion, hasher: &str, query_inserted: bool) { + let query_type = if query_inserted { + "positive" + } else { + "negative" + }; + for item_size in ITEM_SIZES { + for (fp_rate, fp_label) in fp_rates() { + // Create and populate the bloom filter + let mut rng = StdRng::seed_from_u64(42); + let mut bf = + BloomFilter::::with_rate(NonZeroUsize::new(NUM_ITEMS).unwrap(), fp_rate); + let mut inserted_set = HashSet::new(); + + let inserted: Vec> = (0..NUM_ITEMS) + .map(|_| { + let mut item = vec![0u8; item_size]; + rng.fill_bytes(&mut item); + bf.insert(&item); + inserted_set.insert(item.clone()); + item + }) + .collect(); + + // Items to query: inserted ones or guaranteed non-inserted ones + let items = if query_inserted { + inserted + } else { + let mut items = Vec::with_capacity(NUM_ITEMS); + while items.len() < NUM_ITEMS { + let mut item = vec![0u8; item_size]; + rng.fill_bytes(&mut item); + if !inserted_set.contains(&item) { + items.push(item); + } + } + items + }; + + c.bench_function( + &format!( + "{}/hasher={} item_size={} fp_rate={} query={}", + module_path!(), + hasher, + item_size, + fp_label, + query_type + ), + |b| { + let mut idx = 0; + b.iter(|| { + let result = bf.contains(black_box(&items[idx])); + idx = (idx + 1) % items.len(); + result + }); + }, + ); + } + } +} + +fn benchmark_contains(c: &mut Criterion) { + run_contains_bench::(c, "sha256", true); + run_contains_bench::(c, "sha256", false); + run_contains_bench::(c, "blake3", true); + run_contains_bench::(c, "blake3", false); +} + +criterion_group!(benches, benchmark_contains); diff --git a/cryptography/src/bloomfilter/benches/insert.rs b/cryptography/src/bloomfilter/benches/insert.rs new file mode 100644 index 0000000000..124bd78779 --- /dev/null +++ b/cryptography/src/bloomfilter/benches/insert.rs @@ -0,0 +1,65 @@ +use commonware_cryptography::{blake3::Blake3, sha256::Sha256, BloomFilter, Hasher}; +use commonware_utils::rational::BigRationalExt; +use criterion::{criterion_group, BatchSize, Criterion}; +use num_rational::BigRational; +use rand::{rngs::StdRng, RngCore, SeedableRng}; +use std::num::NonZeroUsize; + +const ITEM_SIZES: [usize; 3] = [32, 2048, 4096]; +const NUM_ITEMS: usize = 10000; + +fn fp_rates() -> [(BigRational, &'static str); 2] { + [ + (BigRational::from_frac_u64(1, 10), "10%"), + (BigRational::from_frac_u64(1, 1000), "0.1%"), + ] +} + +fn run_insert_bench(c: &mut Criterion, hasher: &str) { + for item_size in ITEM_SIZES { + for (fp_rate, fp_label) in fp_rates() { + // Pre-generate items to insert + let mut rng = StdRng::seed_from_u64(42); + let items: Vec> = (0..NUM_ITEMS) + .map(|_| { + let mut item = vec![0u8; item_size]; + rng.fill_bytes(&mut item); + item + }) + .collect(); + + c.bench_function( + &format!( + "{}/hasher={} item_size={} fp_rate={}", + module_path!(), + hasher, + item_size, + fp_label + ), + |b| { + let mut idx = 0; + b.iter_batched( + || { + BloomFilter::::with_rate( + NonZeroUsize::new(NUM_ITEMS).unwrap(), + fp_rate.clone(), + ) + }, + |mut bf| { + bf.insert(&items[idx]); + idx = (idx + 1) % items.len(); + }, + BatchSize::SmallInput, + ); + }, + ); + } + } +} + +fn benchmark_insert(c: &mut Criterion) { + run_insert_bench::(c, "sha256"); + run_insert_bench::(c, "blake3"); +} + +criterion_group!(benches, benchmark_insert); diff --git a/cryptography/src/bloomfilter/conformance.rs b/cryptography/src/bloomfilter/conformance.rs new file mode 100644 index 0000000000..cc1b428461 --- /dev/null +++ b/cryptography/src/bloomfilter/conformance.rs @@ -0,0 +1,66 @@ +//! BloomFilter conformance tests + +use super::{BloomFilter, Sha256}; +use commonware_codec::conformance::CodecConformance; +use commonware_conformance::Conformance; +use commonware_utils::rational::BigRationalExt; +use core::num::NonZeroUsize; +use num_rational::BigRational; + +commonware_conformance::conformance_tests! { + CodecConformance, + RationalOptimalBits => 1024, +} + +/// Conformance test for rational-based optimal_bits and with_rate. +/// Verifies that optimal_bits, optimal_hashers, and with_rate produce stable +/// outputs for various expected_items values and FP rates expressed as rationals. +struct RationalOptimalBits; + +impl Conformance for RationalOptimalBits { + async fn commit(seed: u64) -> Vec { + let mut log = Vec::new(); + + // Use seed to vary expected_items (1 to 1M range) + let expected_items = ((seed % 1_000_000) + 1) as usize; + + // Test FP rates as rationals: 1/10000, 1/1000, 1/100, 1/10 + let fp_rates = [ + BigRational::from_frac_u64(1, 10_000), // 0.01% + BigRational::from_frac_u64(1, 1_000), // 0.1% + BigRational::from_frac_u64(1, 100), // 1% + BigRational::from_frac_u64(1, 10), // 10% + ]; + for fp_rate in &fp_rates { + // Test individual functions + let bits = BloomFilter::::optimal_bits(expected_items, fp_rate); + let hashers = BloomFilter::::optimal_hashers(expected_items, bits); + + log.extend((expected_items as u64).to_be_bytes()); + log.extend((bits as u64).to_be_bytes()); + log.extend(hashers.to_be_bytes()); + + // Test with_rate constructor produces same results + let filter = BloomFilter::::with_rate( + NonZeroUsize::new(expected_items).unwrap(), + fp_rate.clone(), + ); + log.extend((filter.bits().get() as u64).to_be_bytes()); + log.extend(filter.hashers().get().to_be_bytes()); + } + + // Test some boundary values + let boundary_rates = [ + BigRational::from_frac_u64(1, 7_000), // Between 0.01% and 0.1% + BigRational::from_frac_u64(1, 500), // Between 0.1% and 1% + BigRational::from_frac_u64(1, 50), // Between 1% and 10% + BigRational::from_frac_u64(3, 100), // 3% + ]; + for fp_rate in &boundary_rates { + let bits = BloomFilter::::optimal_bits(expected_items, fp_rate); + log.extend((bits as u64).to_be_bytes()); + } + + log + } +} diff --git a/cryptography/src/bloomfilter/mod.rs b/cryptography/src/bloomfilter/mod.rs new file mode 100644 index 0000000000..bf44fb7fe0 --- /dev/null +++ b/cryptography/src/bloomfilter/mod.rs @@ -0,0 +1,619 @@ +//! An implementation of a [Bloom Filter](https://en.wikipedia.org/wiki/Bloom_filter). + +#[cfg(all(test, feature = "arbitrary"))] +mod conformance; + +use crate::{sha256::Sha256, Hasher}; +use bytes::{Buf, BufMut}; +use commonware_codec::{ + codec::{Read, Write}, + error::Error as CodecError, + EncodeSize, FixedSize, +}; +use commonware_utils::bitmap::BitMap; +use core::{ + marker::PhantomData, + num::{NonZeroU64, NonZeroU8, NonZeroUsize}, +}; +#[cfg(feature = "std")] +use { + commonware_utils::rational::BigRationalExt, + num_rational::BigRational, + num_traits::{One, ToPrimitive, Zero}, +}; + +/// Rational approximation of ln(2) with 6 digits of precision: 14397/20769. +#[cfg(feature = "std")] +const LN2: (u64, u64) = (14397, 20769); + +/// Rational approximation of 1/ln(2) with 6 digits of precision: 29145/20201. +#[cfg(feature = "std")] +const LN2_INV: (u64, u64) = (29145, 20201); + +/// A [Bloom Filter](https://en.wikipedia.org/wiki/Bloom_filter). +/// +/// This implementation uses the Kirsch-Mitzenmacher optimization to derive `k` hash functions +/// from two hash values, which are in turn derived from a single hash digest. This provides +/// efficient hashing for [BloomFilter::insert] and [BloomFilter::contains] operations. +/// +/// # Hasher Selection +/// +/// The `H` type parameter specifies the hash function to use. It defaults to [Sha256]. +/// The hasher's digest must be at least 16 bytes (128 bits) long, this is enforced at +/// compile time. +/// +/// When choosing a hasher, consider: +/// +/// - **Security**: If the bloom filter accepts untrusted input, use a cryptographically +/// secure hash function to prevent attackers from crafting inputs that cause excessive +/// collisions (degrading the filter to always return `true`). +/// +/// - **Determinism**: If the bloom filter must produce consistent results across runs +/// or machines (e.g. for serialization or consensus-critical applications), avoid keyed +/// or randomized hash functions. Both [Sha256] and [Blake3](crate::blake3::Blake3) +/// are deterministic. +/// +/// - **Performance**: Hash function performance varies with the size of items inserted +/// and queried. [Sha256] is faster for smaller items (up to ~2KB), while +/// [Blake3](crate::blake3::Blake3) is faster for larger items (4KB+). +#[derive(Clone, Debug)] +pub struct BloomFilter { + hashers: u8, + bits: BitMap, + _marker: PhantomData, +} + +impl PartialEq for BloomFilter { + fn eq(&self, other: &Self) -> bool { + self.hashers == other.hashers && self.bits == other.bits + } +} + +impl Eq for BloomFilter {} + +impl BloomFilter { + /// Compile-time assertion that the digest is at least 16 bytes. + const _ASSERT_DIGEST_AT_LEAST_16_BYTES: () = assert!( + ::SIZE >= 16, + "digest must be at least 128 bits (16 bytes)" + ); + + /// Creates a new [BloomFilter] with `hashers` hash functions and `bits` bits. + /// + /// The number of bits will be rounded up to the next power of 2. If that would + /// overflow, the maximum power of 2 for the platform (2^63 on 64-bit) is used. + pub fn new(hashers: NonZeroU8, bits: NonZeroUsize) -> Self { + let bits = bits + .get() + .checked_next_power_of_two() + .unwrap_or(1 << (usize::BITS - 1)); + Self { + hashers: hashers.get(), + bits: BitMap::zeroes(bits as u64), + _marker: PhantomData, + } + } + + /// Creates a new [BloomFilter] with optimal parameters for the expected number + /// of items and desired false positive rate. + /// + /// Uses exact rational arithmetic for full determinism across all platforms. + /// + /// # Arguments + /// + /// * `expected_items` - Number of items expected to be inserted + /// * `fp_rate` - False positive rate as a rational (e.g., `BigRational::from_frac_u64(1, 100)` for 1%) + /// + /// # Panics + /// + /// Panics if `fp_rate` is not in (0, 1). + #[cfg(feature = "std")] + pub fn with_rate(expected_items: NonZeroUsize, fp_rate: BigRational) -> Self { + let bits = Self::optimal_bits(expected_items.get(), &fp_rate); + let hashers = Self::optimal_hashers(expected_items.get(), bits); + Self { + hashers, + bits: BitMap::zeroes(bits as u64), + _marker: PhantomData, + } + } + + /// Returns the number of hashers used by the filter. + pub const fn hashers(&self) -> NonZeroU8 { + NonZeroU8::new(self.hashers).expect("hashers is never zero") + } + + /// Returns the number of bits used by the filter. + pub const fn bits(&self) -> NonZeroUsize { + NonZeroUsize::new(self.bits.len() as usize).expect("bits is never zero") + } + + /// Generate `num_hashers` bit indices for a given item. + fn indices(&self, item: &[u8]) -> impl Iterator { + #[allow(path_statements)] + Self::_ASSERT_DIGEST_AT_LEAST_16_BYTES; + + // Extract two 64-bit hash values from the digest of the item + let digest = H::hash(item); + let h1 = u64::from_be_bytes(digest[0..8].try_into().unwrap()); + let mut h2 = u64::from_be_bytes(digest[8..16].try_into().unwrap()); + + // Ensure h2 is odd (non-zero). If h2 were 0, all k hash functions would + // produce the same index (h1), defeating the purpose of multiple hashers. + h2 |= 1; + + // Generate `hashers` hashes using the Kirsch-Mitzenmacher optimization: + // + // `h_i(x) = (h1(x) + i * h2(x)) mod m` + let hashers = self.hashers as u64; + let mask = self.bits.len() - 1; + (0..hashers).map(move |hasher| h1.wrapping_add(hasher.wrapping_mul(h2)) & mask) + } + + /// Inserts an item into the [BloomFilter]. + pub fn insert(&mut self, item: &[u8]) { + let indices = self.indices(item); + for index in indices { + self.bits.set(index, true); + } + } + + /// Checks if an item is possibly in the [BloomFilter]. + /// + /// Returns `true` if the item is probably in the set, and `false` if it is definitely not. + pub fn contains(&self, item: &[u8]) -> bool { + let indices = self.indices(item); + for index in indices { + if !self.bits.get(index) { + return false; + } + } + true + } + + /// Estimates the current false positive probability. + /// + /// This approximates the false positive rate as `f^k` where `f` is the fill ratio + /// (proportion of bits set to 1) and `k` is the number of hash functions. + /// + /// Returns a [`BigRational`] for exact representation and cross-platform determinism. + #[cfg(feature = "std")] + pub fn estimated_false_positive_rate(&self) -> BigRational { + let ones = self.bits.count_ones(); + let len = self.bits.len(); + let fill_ratio = BigRational::new(ones.into(), len.into()); + fill_ratio.pow(self.hashers as i32) + } + + /// Estimates the number of items that have been inserted. + /// + /// Uses the formula `n = -(m/k) * ln(1 - x/m)` where `m` is the number of bits, + /// `k` is the number of hash functions, and `x` is the number of bits set to 1. + /// + /// Returns a [`BigRational`] using `log2_floor` for the logarithm computation. + #[cfg(feature = "std")] + pub fn estimated_count(&self) -> BigRational { + let m = self.bits.len(); + let x = self.bits.count_ones(); + let k = self.hashers as u64; + if x >= m { + return BigRational::from_usize(usize::MAX); + } + + // ln(1 - x/m) = log2(1 - x/m) * ln(2) + let one_minus_fill = BigRational::new((m - x).into(), m.into()); + let log2_val = one_minus_fill.log2_floor(16); + let ln2 = BigRational::from_frac_u64(LN2.0, LN2.1); + let ln_result = &log2_val * &ln2; + + // n = -(m/k) * ln(1 - x/m) + let m_over_k = BigRational::new(m.into(), k.into()); + -m_over_k * ln_result + } + + /// Calculates the optimal number of hash functions for a given capacity and bit count. + /// + /// Uses [`BigRational`] for determinism. The result is clamped to [1, 16] since + /// beyond ~10-12 hashes provides negligible improvement while increasing CPU cost. + #[cfg(feature = "std")] + pub fn optimal_hashers(expected_items: usize, bits: usize) -> u8 { + if expected_items == 0 { + return 1; + } + + // k = (m/n) * ln(2) + let ln2 = BigRational::from_frac_u64(LN2.0, LN2.1); + let k_ratio = BigRational::from_usize(bits) * ln2 / BigRational::from_usize(expected_items); + k_ratio.to_integer().to_u8().unwrap_or(16).clamp(1, 16) + } + + /// Calculates the optimal number of bits for a given capacity and false positive rate. + /// + /// Uses exact rational arithmetic for full determinism across all platforms. + /// The result is rounded up to the next power of 2. If that would overflow, the maximum + /// power of 2 for the platform (2^63 on 64-bit) is used. + /// + /// Formula: m = -n * log2(p) / ln(2) + /// + /// # Panics + /// + /// Panics if `fp_rate` is not in (0, 1). + #[cfg(feature = "std")] + pub fn optimal_bits(expected_items: usize, fp_rate: &BigRational) -> usize { + assert!( + fp_rate > &BigRational::zero() && fp_rate < &BigRational::one(), + "false positive rate must be in (0, 1)" + ); + + // log2(p) is negative for p < 1. Use floor to get a more negative value, + // which results in more bits (conservative choice to not exceed target FP rate). + let log2_p = fp_rate.log2_floor(16); + + // m = -n * log2(p) / ln(2) = -n * log2(p) * (1/ln(2)) + // Since log2(p) < 0 for p < 1, -log2(p) > 0 + let n = BigRational::from_usize(expected_items); + let ln2_inv = BigRational::from_frac_u64(LN2_INV.0, LN2_INV.1); + let bits_rational = -(&n * &log2_p * &ln2_inv); + + let raw = bits_rational.ceil_to_u128().unwrap_or(1) as usize; + raw.max(1) + .checked_next_power_of_two() + .unwrap_or(1 << (usize::BITS - 1)) + } +} + +impl Write for BloomFilter { + fn write(&self, buf: &mut impl BufMut) { + self.hashers.write(buf); + self.bits.write(buf); + } +} + +impl Read for BloomFilter { + // The number of hashers and the number of bits that the bitmap must have. + type Cfg = (NonZeroU8, NonZeroU64); + + fn read_cfg( + buf: &mut impl Buf, + (hashers_cfg, bits_cfg): &Self::Cfg, + ) -> Result { + if !bits_cfg.get().is_power_of_two() { + return Err(CodecError::Invalid( + "BloomFilter", + "bits must be a power of 2", + )); + } + let hashers = u8::read_cfg(buf, &())?; + if hashers != hashers_cfg.get() { + return Err(CodecError::Invalid( + "BloomFilter", + "hashers doesn't match config", + )); + } + let bits = BitMap::read_cfg(buf, &bits_cfg.get())?; + if bits.len() != bits_cfg.get() { + return Err(CodecError::Invalid( + "BloomFilter", + "bitmap length doesn't match config", + )); + } + Ok(Self { + hashers, + bits, + _marker: PhantomData, + }) + } +} + +impl EncodeSize for BloomFilter { + fn encode_size(&self) -> usize { + self.hashers.encode_size() + self.bits.encode_size() + } +} + +#[cfg(feature = "arbitrary")] +impl arbitrary::Arbitrary<'_> for BloomFilter { + fn arbitrary(u: &mut arbitrary::Unstructured<'_>) -> arbitrary::Result { + // Ensure at least 1 hasher + let hashers = u8::arbitrary(u)?.max(1); + // Generate u64 in u16 range to avoid OOM, then round to power of two + let bits_len = u.int_in_range(0..=u16::MAX as u64)?.next_power_of_two(); + let mut bits = BitMap::with_capacity(bits_len); + for _ in 0..bits_len { + bits.push(u.arbitrary::()?); + } + Ok(Self { + hashers, + bits, + _marker: PhantomData, + }) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use commonware_codec::{Decode, Encode}; + use commonware_utils::{NZUsize, NZU64, NZU8}; + + #[test] + fn test_insert_and_contains() { + let mut bf = BloomFilter::::new(NZU8!(10), NZUsize!(1000)); + let item1 = b"hello"; + let item2 = b"world"; + let item3 = b"bloomfilter"; + + bf.insert(item1); + bf.insert(item2); + + assert!(bf.contains(item1)); + assert!(bf.contains(item2)); + assert!(!bf.contains(item3)); + } + + #[test] + fn test_empty() { + let bf = BloomFilter::::new(NZU8!(5), NZUsize!(100)); + assert!(!bf.contains(b"anything")); + } + + #[test] + fn test_false_positives() { + let mut bf = BloomFilter::::new(NZU8!(10), NZUsize!(100)); + for i in 0..10usize { + bf.insert(&i.to_be_bytes()); + } + + // Check for inserted items + for i in 0..10usize { + assert!(bf.contains(&i.to_be_bytes())); + } + + // Check for non-inserted items and count false positives + let mut false_positives = 0; + for i in 100..1100usize { + if bf.contains(&i.to_be_bytes()) { + false_positives += 1; + } + } + + // A small bloom filter with many items will have some false positives. + // The exact number is probabilistic, but it should not be zero and not all should be FPs. + assert!(false_positives > 0); + assert!(false_positives < 1000); + } + + #[test] + fn test_codec_roundtrip() { + let mut bf = BloomFilter::::new(NZU8!(5), NZUsize!(128)); + bf.insert(b"test1"); + bf.insert(b"test2"); + + let cfg = (NZU8!(5), NZU64!(128)); + + let encoded = bf.encode(); + let decoded = BloomFilter::::decode_cfg(encoded, &cfg).unwrap(); + + assert_eq!(bf, decoded); + } + + #[test] + fn test_codec_empty() { + let bf = BloomFilter::::new(NZU8!(4), NZUsize!(128)); + let cfg = (NZU8!(4), NZU64!(128)); + let encoded = bf.encode(); + let decoded = BloomFilter::::decode_cfg(encoded, &cfg).unwrap(); + assert_eq!(bf, decoded); + } + + #[test] + fn test_codec_with_invalid_hashers() { + let mut bf = BloomFilter::::new(NZU8!(5), NZUsize!(128)); + bf.insert(b"test1"); + let encoded = bf.encode(); + + // Too large + let cfg = (NZU8!(10), NZU64!(128)); + let decoded = BloomFilter::::decode_cfg(encoded.clone(), &cfg); + assert!(matches!( + decoded, + Err(CodecError::Invalid( + "BloomFilter", + "hashers doesn't match config" + )) + )); + + // Too small + let cfg = (NZU8!(4), NZU64!(128)); + let decoded = BloomFilter::::decode_cfg(encoded, &cfg); + assert!(matches!( + decoded, + Err(CodecError::Invalid( + "BloomFilter", + "hashers doesn't match config" + )) + )); + } + + #[test] + fn test_codec_with_invalid_bits() { + let mut bf = BloomFilter::::new(NZU8!(5), NZUsize!(128)); + bf.insert(b"test1"); + let encoded = bf.encode(); + + // Wrong bit count + let cfg = (NZU8!(5), NZU64!(64)); + let result = BloomFilter::::decode_cfg(encoded.clone(), &cfg); + assert!(matches!(result, Err(CodecError::InvalidLength(128)))); + + let cfg = (NZU8!(5), NZU64!(256)); + let result = BloomFilter::::decode_cfg(encoded.clone(), &cfg); + assert!(matches!( + result, + Err(CodecError::Invalid( + "BloomFilter", + "bitmap length doesn't match config" + )) + )); + + // Non-power-of-2 bits + let cfg = (NZU8!(5), NZU64!(100)); + let result = BloomFilter::::decode_cfg(encoded, &cfg); + assert!(matches!( + result, + Err(CodecError::Invalid( + "BloomFilter", + "bits must be a power of 2" + )) + )); + } + + #[test] + fn test_statistics() { + let mut bf = BloomFilter::::new(NZU8!(7), NZUsize!(1024)); + + // Empty filter should have 0 estimated count and FP rate + assert_eq!(bf.estimated_count(), BigRational::zero()); + assert_eq!(bf.estimated_false_positive_rate(), BigRational::zero()); + + // Insert some items + for i in 0..100usize { + bf.insert(&i.to_be_bytes()); + } + + // Estimated count should be reasonably close to 100 + let estimated = bf.estimated_count(); + let lower = BigRational::from_usize(75); + let upper = BigRational::from_usize(125); + assert!(estimated > lower && estimated < upper); + + // FP rate should be non-zero after insertions + assert!(bf.estimated_false_positive_rate() > BigRational::zero()); + assert!(bf.estimated_false_positive_rate() < BigRational::one()); + } + + #[test] + fn test_with_rate() { + // Create a filter for 1000 items with 1% false positive rate + let fp_rate = BigRational::from_frac_u64(1, 100); + let mut bf = BloomFilter::::with_rate(NZUsize!(1000), fp_rate.clone()); + + // Verify getters return expected values + let expected_bits = BloomFilter::::optimal_bits(1000, &fp_rate); + let expected_hashers = BloomFilter::::optimal_hashers(1000, expected_bits); + assert_eq!(bf.bits().get(), expected_bits); + assert_eq!(bf.hashers().get(), expected_hashers); + + // Insert 1000 items + for i in 0..1000usize { + bf.insert(&i.to_be_bytes()); + } + + // All inserted items should be found + for i in 0..1000usize { + assert!(bf.contains(&i.to_be_bytes())); + } + + // Count false positives on non-inserted items + let mut false_positives = 0; + for i in 1000..2000usize { + if bf.contains(&i.to_be_bytes()) { + false_positives += 1; + } + } + + // With 1% target FP rate, we expect around 10 false positives out of 1000 + // Allow some variance (should be well under 2%) + assert!(false_positives < 20); + } + + #[test] + fn test_optimal_hashers() { + // For 1000 items in 10000 bits, optimal k = (10000/1000) * ln(2) = 6.93 + // Integer math truncates to 6 + let k = BloomFilter::::optimal_hashers(1000, 10000); + assert_eq!(k, 6); + + // For 100 items in 1000 bits, optimal k = (1000/100) * ln(2) = 6.93 + // Integer math truncates to 6 + let k = BloomFilter::::optimal_hashers(100, 1000); + assert_eq!(k, 6); + + // Edge case: very few bits per item, clamped to 1 + let k = BloomFilter::::optimal_hashers(1000, 100); + assert_eq!(k, 1); + + // Edge case: many bits per item, clamped to 16 + let k = BloomFilter::::optimal_hashers(100, 100000); + assert_eq!(k, 16); + + // Edge case: zero items returns 1 + let k = BloomFilter::::optimal_hashers(0, 1000); + assert_eq!(k, 1); + + // Edge case: extreme values that would overflow (n << 16 wraps to 0 for n >= 2^48) + // Should not panic, should return clamped value + let k = BloomFilter::::optimal_hashers(1 << 48, 1000); + assert_eq!(k, 1); + let k = BloomFilter::::optimal_hashers(usize::MAX, usize::MAX); + assert!((1..=16).contains(&k)); + } + + #[test] + fn test_optimal_bits() { + // For 1000 items with 1% FP rate + // Formula: m = -n * ln(p) / (ln(2))^2 = -1000 * ln(0.01) / 0.4804 = 9585 + // Rounded to next power of 2 = 16384 + let fp_1pct = BigRational::from_frac_u64(1, 100); + let bits = BloomFilter::::optimal_bits(1000, &fp_1pct); + assert_eq!(bits, 16384); + assert!(bits.is_power_of_two()); + + // For 10000 items with 0.001% FP rate (need significantly more bits) + // Formula: m = -10000 * ln(0.00001) / 0.4804 = 239627 + // Rounded to next power of 2 = 262144 + let fp_001pct = BigRational::from_frac_u64(1, 100_000); + let bits_lower_fp = BloomFilter::::optimal_bits(10000, &fp_001pct); + assert_eq!(bits_lower_fp, 262144); + assert!(bits_lower_fp.is_power_of_two()); + } + + #[test] + fn test_bits_extreme_values() { + let fp_001pct = BigRational::from_frac_u64(1, 10_000); + let fp_1pct = BigRational::from_frac_u64(1, 100); + + // Very large expected_items + let bits = BloomFilter::::optimal_bits(usize::MAX / 2, &fp_001pct); + assert!(bits.is_power_of_two()); + assert!(bits > 0); + + // Large but reasonable values + let bits = BloomFilter::::optimal_bits(1_000_000_000, &fp_001pct); + assert!(bits.is_power_of_two()); + + // Zero items + let bits = BloomFilter::::optimal_bits(0, &fp_1pct); + assert!(bits.is_power_of_two()); + assert_eq!(bits, 1); // 0 * bpe rounds up to 1 + } + + #[test] + fn test_with_rate_deterministic() { + let fp_rate = BigRational::from_frac_u64(1, 100); + let bf1 = BloomFilter::::with_rate(NZUsize!(1000), fp_rate.clone()); + let bf2 = BloomFilter::::with_rate(NZUsize!(1000), fp_rate); + assert_eq!(bf1.bits(), bf2.bits()); + assert_eq!(bf1.hashers(), bf2.hashers()); + } + + #[test] + fn test_optimal_bits_matches_formula() { + // For 1000 items at 1% FP rate + // m = -1000 * log2(0.01) / ln(2) = 9585 + // Rounded to power of 2 = 16384 + let fp_rate = BigRational::from_frac_u64(1, 100); + let bits = BloomFilter::::optimal_bits(1000, &fp_rate); + assert_eq!(bits, 16384); + } +} diff --git a/utils/src/rational.rs b/utils/src/rational.rs index 888b8744ba..5764eec1cd 100644 --- a/utils/src/rational.rs +++ b/utils/src/rational.rs @@ -1,10 +1,88 @@ //! Utilities for working with `num_rational::BigRational`. -use num_bigint::BigInt; +use num_bigint::{BigInt, BigUint}; use num_integer::Integer; use num_rational::BigRational; use num_traits::{One, ToPrimitive, Zero}; +/// Computes log2 of a rational number with specified binary precision. +/// +/// Returns `(numerator, has_remainder)` where the result is `numerator / 2^binary_digits` +/// and `has_remainder` indicates whether there's additional precision beyond what was computed. +fn log2(numer: BigUint, denom: BigUint, binary_digits: usize) -> (BigInt, bool) { + // Compute the integer part of log2(numer/denom) by comparing bit lengths. + // Since log2(numer/denom) = log2(numer) - log2(denom), and bits() gives us + // floor(log2(x)) + 1, we can compute the integer part directly. + let numer_bits = numer.bits(); + let denom_bits = denom.bits(); + let mut integer_part = numer_bits as i128 - denom_bits as i128; + + // Align the most significant bits of numerator and denominator to bring + // the ratio into the range [1, 2). By shifting both values to have the same bit + // length, we normalize the ratio in a single operation. + let mut normalized_numer = numer; + if denom_bits > numer_bits { + normalized_numer <<= denom_bits - numer_bits; + } + let mut normalized_denom = denom; + if numer_bits > denom_bits { + normalized_denom <<= numer_bits - denom_bits; + } + + // After alignment, we may need one additional shift to ensure normalized value is in [1, 2). + if normalized_numer < normalized_denom { + normalized_numer <<= 1; + integer_part -= 1; + } + + // Handle the special case where the value is exactly a power of 2. + // In this case, log2(x) is exact and has no fractional component. + if normalized_numer == normalized_denom { + return (BigInt::from(integer_part) << binary_digits, false); + } + + // Extract binary fractional digits using the square-and-compare method. + // At this point, normalized is in (1, 2), so log2(normalized) is in (0, 1). + // We use integer-only arithmetic to avoid BigRational division overhead: + // Instead of squaring the rational and comparing to 2, we square the numerator + // and denominator separately and check if numer^2 >= 2 * denom^2. + let mut fractional_bits = BigInt::zero(); + let one = BigInt::one(); + + for _ in 0..binary_digits { + // Square both numerator and denominator to shift the next binary digit into position. + let numer_squared = &normalized_numer * &normalized_numer; + let denom_squared = &normalized_denom * &normalized_denom; + + // Left-shift the fractional bits accumulator to make room for the new bit. + fractional_bits <<= 1; + + // If squared value >= 2, the next binary digit is 1. + // We renormalize by dividing by 2, which is equivalent to multiplying the denominator by 2. + let two_denom_squared = &denom_squared << 1; + if numer_squared >= two_denom_squared { + fractional_bits |= &one; + normalized_numer = numer_squared; + normalized_denom = two_denom_squared; + } else { + normalized_numer = numer_squared; + normalized_denom = denom_squared; + } + } + + // Combine integer and fractional parts. + // We return a single rational number with denominator 2^binary_digits. + // By left-shifting the integer part, we convert it to the same "units" as fractional_bits, + // allowing us to add them: numerator = (integer_part * 2^binary_digits) + fractional_bits. + // This represents: integer_part + fractional_bits / (2^binary_digits) + let numerator = (BigInt::from(integer_part) << binary_digits) + fractional_bits; + + // has_remainder is true if there's any leftover mass in the normalized value + // after extracting all digits. This happens when normalized_numer > normalized_denom. + let has_remainder = normalized_numer > normalized_denom; + (numerator, has_remainder) +} + /// Extension trait adding convenience constructors for [`BigRational`]. pub trait BigRationalExt { /// Creates a [`BigRational`] from a `u64` numerator with denominator `1`. @@ -62,6 +140,29 @@ pub trait BigRationalExt { /// assert!(result <= BigRational::from_u64(2)); /// ``` fn log2_ceil(&self, binary_digits: usize) -> BigRational; + + /// Computes the floor of log2 of this rational number with specified binary precision. + /// + /// Returns log2(x) rounded down to the nearest value representable with `binary_digits` + /// fractional bits in binary representation. + /// + /// # Panics + /// + /// Panics if the rational number is non-positive. + /// + /// # Examples + /// + /// ``` + /// use num_rational::BigRational; + /// use commonware_utils::rational::BigRationalExt; + /// + /// let x = BigRational::from_frac_u64(3, 1); // 3 + /// let result = x.log2_floor(4); + /// // log2(3) ≈ 1.585, the algorithm computes a floor approximation + /// assert!(result >= BigRational::from_u64(1)); + /// assert!(result <= BigRational::from_u64(2)); + /// ``` + fn log2_floor(&self, binary_digits: usize) -> BigRational; } impl BigRationalExt for BigRational { @@ -112,89 +213,25 @@ impl BigRationalExt for BigRational { panic!("log2 undefined for non-positive numbers"); } - // Step 1: Extract numerator and denominator as unsigned integers for efficient computation. let numer = self.numer().to_biguint().expect("positive"); let denom = self.denom().to_biguint().expect("positive"); - - // Step 2: Compute the integer part of log2(numer/denom) by comparing bit lengths. - // Since log2(numer/denom) = log2(numer) - log2(denom), and bits() gives us - // floor(log2(x)) + 1, we can compute the integer part directly. - let numer_bits = numer.bits(); - let denom_bits = denom.bits(); - let mut integer_part = numer_bits as i128 - denom_bits as i128; - - // Step 3: Align the most significant bits of numerator and denominator to bring - // the ratio into the range [1, 2). By shifting both values to have the same bit - // length, we normalize the ratio in a single operation. - let mut normalized_numer = numer; - if denom_bits > numer_bits { - normalized_numer <<= denom_bits - numer_bits; - } - let mut normalized_denom = denom; - if numer_bits > denom_bits { - normalized_denom <<= numer_bits - denom_bits; - } - - // After alignment, we may need one additional shift to ensure normalized value is in [1, 2). - if normalized_numer < normalized_denom { - normalized_numer <<= 1; - integer_part -= 1; - } - assert!( - normalized_numer >= normalized_denom && normalized_numer < (&normalized_denom << 1) - ); - - // Step 4: Handle the special case where the value is exactly a power of 2. - // In this case, log2(x) is exact and has no fractional component. - if normalized_numer == normalized_denom { - let numerator = BigInt::from(integer_part) << binary_digits; - let denominator = BigInt::one() << binary_digits; - return Self::new(numerator, denominator); - } - - // Step 5: Extract binary fractional digits using the square-and-compare method. - // At this point, normalized is in (1, 2), so log2(normalized) is in (0, 1). - // We use integer-only arithmetic to avoid BigRational division overhead: - // Instead of squaring the rational and comparing to 2, we square the numerator - // and denominator separately and check if numer^2 >= 2 * denom^2. - let mut fractional_bits = BigInt::zero(); - let one = BigInt::one(); - - for _ in 0..binary_digits { - // Square both numerator and denominator to shift the next binary digit into position. - let numer_squared = &normalized_numer * &normalized_numer; - let denom_squared = &normalized_denom * &normalized_denom; - - // Left-shift the fractional bits accumulator to make room for the new bit. - fractional_bits <<= 1; - - // If squared value >= 2, the next binary digit is 1. - // We renormalize by dividing by 2, which is equivalent to multiplying the denominator by 2. - let two_denom_squared = &denom_squared << 1; - if numer_squared >= two_denom_squared { - fractional_bits |= &one; - normalized_numer = numer_squared; - normalized_denom = two_denom_squared; - } else { - normalized_numer = numer_squared; - normalized_denom = denom_squared; - } + let (mut numerator, has_remainder) = log2(numer, denom, binary_digits); + if has_remainder { + numerator += 1; } + let denominator = BigInt::one() << binary_digits; + Self::new(numerator, denominator) + } - // Step 6: Combine integer and fractional parts, then apply ceiling operation. - // We need to return a single rational number with denominator 2^binary_digits. - // By left-shifting the integer part, we convert it to the same "units" as fractional_bits, - // allowing us to add them: numerator = (integer_part * 2^binary_digits) + fractional_bits. - // This represents: integer_part + fractional_bits / (2^binary_digits) - let mut numerator = (BigInt::from(integer_part) << binary_digits) + fractional_bits; - - // If there's any leftover mass in the normalized value after extracting all digits, - // we need to round up (ceiling operation). This happens when normalized_numer > normalized_denom. - if normalized_numer > normalized_denom { - numerator += &one; + fn log2_floor(&self, binary_digits: usize) -> BigRational { + if self <= &Self::zero() { + panic!("log2 undefined for non-positive numbers"); } - let denominator = one << binary_digits; + let numer = self.numer().to_biguint().expect("positive"); + let denom = self.denom().to_biguint().expect("positive"); + let (numerator, _) = log2(numer, denom, binary_digits); + let denominator = BigInt::one() << binary_digits; Self::new(numerator, denominator) } } @@ -427,4 +464,214 @@ mod tests { let x = BigRational::new(num, den); assert_eq!(x.log2_ceil(8), BigRational::from_frac_u64(23, 256)); } + + #[test] + #[should_panic(expected = "log2 undefined for non-positive numbers")] + fn log2_floor_negative_panics() { + ::from_i64(-1) + .unwrap() + .log2_floor(8); + } + + #[test] + fn log2_floor_exact_powers_of_two() { + // Test exact powers of 2: log2(2^n) = n (same as ceil) + let value = BigRational::from_u64(1); // 2^0 + assert_eq!(value.log2_floor(4), BigRational::from_u64(0)); + + let value = BigRational::from_u64(2); // 2^1 + assert_eq!(value.log2_floor(4), BigRational::from_u64(1)); + + let value = BigRational::from_u64(8); // 2^3 + assert_eq!(value.log2_floor(4), BigRational::from_u64(3)); + + let value = BigRational::from_u64(1024); // 2^10 + assert_eq!(value.log2_floor(4), BigRational::from_u64(10)); + } + + #[test] + fn log2_floor_fractional_powers_of_two() { + // Test fractional powers of 2: log2(1/2) = -1, log2(1/4) = -2 (same as ceil) + let value = BigRational::from_frac_u64(1, 2); // 2^(-1) + let result = value.log2_floor(4); + assert_eq!(result, BigRational::from_integer(BigInt::from(-1))); + + let value = BigRational::from_frac_u64(1, 4); // 2^(-2) + let result = value.log2_floor(4); + assert_eq!(result, BigRational::from_integer(BigInt::from(-2))); + + // log2(3/8) ≈ -1.415, floor(-1.415 * 16) / 16 = -23/16 + let value = BigRational::from_frac_u64(3, 8); + let result = value.log2_floor(4); + assert_eq!( + result, + BigRational::new(BigInt::from(-23), BigInt::from(16)) + ); + } + + #[test] + fn log2_floor_simple_values() { + // log2(3) ≈ 1.585, with binary_digits=0 we get floor(1.585) = 1 + let value = BigRational::from_u64(3); + let result = value.log2_floor(0); + assert_eq!(result, BigRational::from_u64(1)); + + // log2(5) ≈ 2.322, with binary_digits=0 we get floor(2.322) = 2 + let value = BigRational::from_u64(5); + let result = value.log2_floor(0); + assert_eq!(result, BigRational::from_u64(2)); + + // With 4 bits precision + // log2(3) ≈ 1.585, floor(1.585 * 16) / 16 = 25/16 + let value = BigRational::from_u64(3); + let result = value.log2_floor(4); + assert_eq!(result, BigRational::from_frac_u64(25, 16)); + } + + #[test] + fn log2_floor_rational_values() { + // log2(3/2) ≈ 0.585, floor(0.585 * 16) / 16 = 9/16 + let value = BigRational::from_frac_u64(3, 2); + let result = value.log2_floor(4); + assert_eq!(result, BigRational::from_frac_u64(9, 16)); + + // log2(7/4) ≈ 0.807, floor(0.807 * 16) / 16 = 12/16 + let value = BigRational::from_frac_u64(7, 4); + let result = value.log2_floor(4); + assert_eq!(result, BigRational::from_frac_u64(12, 16)); + } + + #[test] + fn log2_floor_different_precisions() { + let value = BigRational::from_u64(3); + + // Test different precisions give reasonable results + let result0 = value.log2_floor(0); + let result1 = value.log2_floor(1); + let result4 = value.log2_floor(4); + let result8 = value.log2_floor(8); + + assert_eq!(result0, BigRational::from_u64(1)); + assert_eq!(result1, BigRational::from_frac_u64(3, 2)); + assert_eq!(result4, BigRational::from_frac_u64(25, 16)); + assert_eq!( + result8, + BigRational::new(BigInt::from(405), BigInt::from(256)) + ); + } + + #[test] + fn log2_floor_large_values() { + // log2(1000) ≈ 9.966, floor = 9 + let value = BigRational::from_u64(1000); + let result = value.log2_floor(4); + assert_eq!(result, BigRational::from_frac_u64(159, 16)); + } + + #[test] + fn log2_floor_very_small_values() { + // log2(1/1000) ≈ -9.966 + let value = BigRational::from_frac_u64(1, 1000); + let result = value.log2_floor(4); + assert_eq!( + result, + BigRational::new(BigInt::from(-160), BigInt::from(16)) + ); + } + + #[test] + fn log2_floor_edge_cases() { + // -- Just above a power of two + // log2(17/16) ≈ 0.087462, k=8 → floor(0.087462 * 256) = 22 → 22/256 + let x = BigRational::from_frac_u64(17, 16); + assert_eq!(x.log2_floor(8), BigRational::from_frac_u64(22, 256)); + + // log2(129/128) ≈ 0.011227, k=8 → floor(0.011227 * 256) = 2 → 2/256 + let x = BigRational::from_frac_u64(129, 128); + assert_eq!(x.log2_floor(8), BigRational::from_frac_u64(2, 256)); + + // log2(33/32) ≈ 0.044394, k=10 → floor(0.044394 * 1024) = 45 → 45/1024 + let x = BigRational::from_frac_u64(33, 32); + assert_eq!(x.log2_floor(10), BigRational::from_frac_u64(45, 1024)); + + // -- Just below a power of two (negative, but tiny in magnitude) + // log2(255/256) ≈ -0.00565, k=8 → floor(-0.00565 * 256) = -2 → -2/256 + let x = BigRational::from_frac_u64(255, 256); + assert_eq!( + x.log2_floor(8), + BigRational::new((-2).into(), 256u32.into()) + ); + + // log2(1023/1024) ≈ -0.00141, k=9 → floor(-0.00141 * 512) = -1 → -1/512 + let x = BigRational::from_frac_u64(1023, 1024); + assert_eq!( + x.log2_floor(9), + BigRational::new((-1).into(), 512u32.into()) + ); + + // -- k = 0 (integer floor of log2) + // log2(3/2) ≈ 0.585 ⇒ floor = 0 + let x = BigRational::from_frac_u64(3, 2); + assert_eq!(x.log2_floor(0), BigRational::from_integer(0.into())); + + // log2(3/4) ≈ -0.415 ⇒ floor = -1 + let x = BigRational::from_frac_u64(3, 4); + assert_eq!(x.log2_floor(0), BigRational::from_integer((-1).into())); + + // -- x < 1 with fractional bits (negative dyadic output) + // log2(3/4) ≈ -0.415, k=4 → floor(-0.415 * 16) = -7 → -7/16 + let x = BigRational::from_frac_u64(3, 4); + assert_eq!(x.log2_floor(4), BigRational::new((-7).into(), 16u32.into())); + + // -- Monotonic with k: increasing k refines the dyadic downwards + // For 257/256: k=8 → floor(0.00563*256) = 1 → 1/256 + // k=9 → floor(0.00563*512) = 2 → 2/512 + let x = BigRational::from_frac_u64(257, 256); + assert_eq!(x.log2_floor(8), BigRational::new(1.into(), 256u32.into())); + assert_eq!(x.log2_floor(9), BigRational::new(2.into(), 512u32.into())); + + // -- Scale invariance (multiply numerator and denominator by same factor, result unchanged) + // (17/16) * (2^30 / 2^30) has the same log2, the dyadic result should match 22/256 at k=8. + let num = BigInt::from(17u32) << 30; + let den = BigInt::from(16u32) << 30; + let x = BigRational::new(num, den); + assert_eq!(x.log2_floor(8), BigRational::from_frac_u64(22, 256)); + } + + #[test] + fn log2_floor_ceil_relationship() { + // For any non-power-of-2 value, floor < ceil + // For exact powers of 2, floor == ceil + let test_values = [ + BigRational::from_u64(3), + BigRational::from_u64(5), + BigRational::from_frac_u64(3, 2), + BigRational::from_frac_u64(7, 8), + BigRational::from_frac_u64(1, 3), + ]; + + for value in test_values { + let floor = value.log2_floor(8); + let ceil = value.log2_ceil(8); + assert!( + floor < ceil, + "floor should be less than ceil for non-power-of-2" + ); + } + + // Exact powers of 2 + let powers_of_two = [ + BigRational::from_u64(1), + BigRational::from_u64(2), + BigRational::from_u64(4), + BigRational::from_frac_u64(1, 2), + BigRational::from_frac_u64(1, 4), + ]; + + for value in powers_of_two { + let floor = value.log2_floor(8); + let ceil = value.log2_ceil(8); + assert_eq!(floor, ceil, "floor should equal ceil for power of 2"); + } + } }