Skip to content

Commit

Permalink
Resolved issue for high precision MLE estimation
Browse files Browse the repository at this point in the history
  • Loading branch information
LucaCappelletti94 committed Aug 15, 2024
1 parent bc22970 commit 13b819b
Show file tree
Hide file tree
Showing 2 changed files with 189 additions and 40 deletions.
152 changes: 115 additions & 37 deletions src/core/src/sketch/hyperloglog/estimators.rs
Original file line number Diff line number Diff line change
@@ -1,38 +1,91 @@
use std::cmp;
use core::{
cmp,
ops::{Add, AddAssign, Shl, Sub, SubAssign},
};

pub type CounterType = u8;

pub fn counts(registers: &[CounterType], q: usize) -> Vec<u16> {
let mut counts = vec![0; q + 2];
/// Trait for types that can be used as multiplicity integers.
pub trait MulteplicityInteger:
Shl<usize, Output = Self>
+ Copy
+ AddAssign
+ SubAssign
+ Eq
+ Sub<Self, Output = Self>
+ Add<Self, Output = Self>
+ TryFrom<usize>
+ Ord
{
/// The zero value.
const ZERO: Self;
/// The one value.
const ONE: Self;

/// Convert the value to a `f64`.
fn to_f64(self) -> f64;
}

macro_rules! impl_multeplicity_integer {
($($t:ty),*) => {
$(
impl MulteplicityInteger for $t {
const ONE: Self = 1;
const ZERO: Self = 0;

fn to_f64(self) -> f64 {
self as f64
}
}
)*
};
}

impl_multeplicity_integer!(u8, u16, u32);

pub fn counts<M: MulteplicityInteger>(registers: &[CounterType], q: usize) -> Vec<M> {
let mut counts = vec![M::ZERO; q + 2];

for k in registers {
counts[*k as usize] += 1;
counts[*k as usize] += M::ONE;
}

counts
}

#[allow(clippy::many_single_char_names)]
pub fn mle(counts: &[u16], p: usize, q: usize, relerr: f64) -> f64 {
let m = 1 << p;
pub fn mle<M: MulteplicityInteger>(counts: &[M], p: usize, q: usize, relerr: f64) -> f64 {
let m: M = M::ONE << p;

// If all of the registers are equal to zero, then we return zero.
if counts[0] == m {
return 0.0;
}

// If all of the registers are equal to the maximal possible value
// that a register may have, then we return infinity.
if counts[q + 1] == m {
return f64::INFINITY;
}

let (k_min, _) = counts.iter().enumerate().find(|(_, v)| **v != 0).unwrap();
let (k_min, _) = counts
.iter()
.enumerate()
.find(|(_, v)| **v != M::ZERO)
.unwrap();
let k_min_prime = cmp::max(1, k_min);

let (k_max, _) = counts
.iter()
.enumerate()
.rev()
.find(|(_, v)| **v != 0)
.find(|(_, v)| **v != M::ZERO)
.unwrap();
let k_max_prime = cmp::min(q, k_max);

let mut z = 0.;
for i in num_iter::range_step_inclusive(k_max_prime as i32, k_min_prime as i32, -1) {
z = 0.5 * z + counts[i as usize] as f64;
z = 0.5 * z + counts[i as usize].to_f64();
}

// ldexp(x, i) = x * (2 ** i)
Expand All @@ -44,9 +97,9 @@ pub fn mle(counts: &[u16], p: usize, q: usize, relerr: f64) -> f64 {
}

let mut g_prev = 0.;
let a = z + (counts[0] as f64);
let b = z + (counts[q + 1] as f64) * 2f64.powi(-(q as i32));
let m_prime = (m - counts[0]) as f64;
let a = z + (counts[0].to_f64());
let b = z + (counts[q + 1].to_f64()) * 2f64.powi(-(q as i32));
let m_prime = (m - counts[0]).to_f64();

let mut x = if b <= 1.5 * a {
// weak lower bound (47)
Expand All @@ -57,7 +110,7 @@ pub fn mle(counts: &[u16], p: usize, q: usize, relerr: f64) -> f64 {
};

let mut delta_x = x;
let del = relerr / (m as f64).sqrt();
let del = relerr / m.to_f64().sqrt();
while delta_x > x * del {
// secant method iteration

Expand All @@ -78,13 +131,13 @@ pub fn mle(counts: &[u16], p: usize, q: usize, relerr: f64) -> f64 {
}

// compare (53)
let mut g = c_prime as f64 * h;
let mut g = c_prime.to_f64() * h;

for k in num_iter::range_step_inclusive(k_max_prime as i32 - 1, k_min_prime as i32, -1) {
let h_prime = 1. - h;
// Calculate h(x/2^k), see (56), at this point x_prime = x / (2^(k+2))
h = (x_prime + h * h_prime) / (x_prime + h_prime);
g += counts[k as usize] as f64 * h;
g += counts[k as usize].to_f64() * h;
x_prime += x_prime;
}

Expand All @@ -100,7 +153,7 @@ pub fn mle(counts: &[u16], p: usize, q: usize, relerr: f64) -> f64 {
g_prev = g
}

m as f64 * x
m.to_f64() * x
}

/// Calculate the joint maximum likelihood of A and B.
Expand All @@ -111,57 +164,82 @@ pub fn joint_mle(
k2: &[CounterType],
p: usize,
q: usize,
) -> (usize, usize, usize) {
let mut c1 = vec![0; q + 2];
let mut c2 = vec![0; q + 2];
let mut cu = vec![0; q + 2];
let mut cg1 = vec![0; q + 2];
let mut cg2 = vec![0; q + 2];
let mut ceq = vec![0; q + 2];
) -> (usize, usize, usize)
{
if p < 8 {
joint_mle_dispatch::<u8>(k1, k2, p, q)
} else if p < 16 {
joint_mle_dispatch::<u16>(k1, k2, p, q)
} else {
assert!(p == 16 || p == 17 || p == 18);
joint_mle_dispatch::<u32>(k1, k2, p, q)
}
}

/// Calculate the joint maximum likelihood of A and B.
///
/// Returns a tuple (only in A, only in B, intersection)
fn joint_mle_dispatch<M: MulteplicityInteger>(
k1: &[CounterType],
k2: &[CounterType],
p: usize,
q: usize,
) -> (usize, usize, usize)
where
<M as TryFrom<usize>>::Error: std::fmt::Debug,
{
let mut c1 = vec![M::ZERO; q + 2];
let mut c2 = vec![M::ZERO; q + 2];
let mut cu = vec![M::ZERO; q + 2];
let mut cg1 = vec![M::ZERO; q + 2];
let mut cg2 = vec![M::ZERO; q + 2];
let mut ceq = vec![M::ZERO; q + 2];

for (k1_, k2_) in k1.iter().zip(k2.iter()) {
match k1_.cmp(k2_) {
cmp::Ordering::Less => {
c1[*k1_ as usize] += 1;
cg2[*k2_ as usize] += 1;
c1[*k1_ as usize] += M::ONE;
cg2[*k2_ as usize] += M::ONE;
}
cmp::Ordering::Greater => {
cg1[*k1_ as usize] += 1;
c2[*k2_ as usize] += 1;
cg1[*k1_ as usize] += M::ONE;
c2[*k2_ as usize] += M::ONE;
}
cmp::Ordering::Equal => {
ceq[*k1_ as usize] += 1;
ceq[*k1_ as usize] += M::ONE;
}
}
cu[*cmp::max(k1_, k2_) as usize] += 1;
cu[*cmp::max(k1_, k2_) as usize] += M::ONE;
}

for (i, (v, u)) in cg1.iter().zip(ceq.iter()).enumerate() {
for (i, (&v, &u)) in cg1.iter().zip(ceq.iter()).enumerate() {
c1[i] += v + u;
}

for (i, (v, u)) in cg2.iter().zip(ceq.iter()).enumerate() {
for (i, (&v, &u)) in cg2.iter().zip(ceq.iter()).enumerate() {
c2[i] += v + u;
}

let c_ax = mle(&c1, p, q, 0.01);
let c_bx = mle(&c2, p, q, 0.01);
let c_abx = mle(&cu, p, q, 0.01);

let mut counts_axb_half = vec![0u16; q + 2];
let mut counts_bxa_half = vec![0u16; q + 2];
let mut counts_axb_half = vec![M::ZERO; q + 2];
let mut counts_bxa_half = vec![M::ZERO; q + 2];

counts_axb_half[q] = k1.len() as u16;
counts_bxa_half[q] = k2.len() as u16;
counts_axb_half[q] = M::try_from(k1.len()).unwrap();
counts_bxa_half[q] = M::try_from(k2.len()).unwrap();

for _q in 0..q {
counts_axb_half[_q] = cg1[_q] + ceq[_q] + cg2[_q + 1];
debug_assert!(counts_axb_half[q] >= counts_axb_half[_q]);
counts_axb_half[q] -= counts_axb_half[_q];
let multeplicity_q = counts_axb_half[_q];
counts_axb_half[q] -= multeplicity_q;

counts_bxa_half[_q] = cg2[_q] + ceq[_q] + cg1[_q + 1];
debug_assert!(counts_bxa_half[q] >= counts_bxa_half[_q]);
counts_bxa_half[q] -= counts_bxa_half[_q];
let multeplicity_q = counts_bxa_half[_q];
counts_bxa_half[q] -= multeplicity_q;
}

let c_axb_half = mle(&counts_axb_half, p, q - 1, 0.01);
Expand Down
77 changes: 74 additions & 3 deletions src/core/src/sketch/hyperloglog/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -81,9 +81,29 @@ impl HyperLogLog {
}

pub fn cardinality(&self) -> usize {
let counts = estimators::counts(&self.registers, self.q);

estimators::mle(&counts, self.p, self.q, 0.01) as usize
if self.p < 8 {
estimators::mle(
&estimators::counts::<u8>(&self.registers, self.q),
self.p,
self.q,
0.01,
) as usize
} else if self.p < 16 {
estimators::mle(
&estimators::counts::<u16>(&self.registers, self.q),
self.p,
self.q,
0.05,
) as usize
} else {
assert!(self.p == 16 || self.p == 17 || self.p == 18);
estimators::mle(
&estimators::counts::<u32>(&self.registers, self.q),
self.p,
self.q,
0.1,
) as usize
}
}

pub fn similarity(&self, other: &HyperLogLog) -> f64 {
Expand Down Expand Up @@ -224,8 +244,10 @@ impl Update<HyperLogLog> for KmerMinHash {
#[cfg(test)]
mod test {
use std::collections::HashSet;
use std::hash::{DefaultHasher, Hash};
use std::io::{BufReader, BufWriter, Read};
use std::path::PathBuf;
use std::hash::Hasher;

use crate::signature::SigsTrait;
use needletail::{parse_fastx_file, parse_fastx_reader, Sequence};
Expand Down Expand Up @@ -374,4 +396,53 @@ mod test {
assert_eq!(hll_new.registers, hll.registers);
assert_eq!(hll_new.ksize, hll.ksize);
}

#[test]
/// Test to cover corner cases in the MLE calculation
/// that may happen at resolutions 16, 17 or 18, i.e.
/// cases with 2^16 == 65536, 2^17 == 131072, 2^18 == 262144.
///
/// In such cases, the MLE multeplicities which were earlier
/// implemented always using a u16 type, may overflow.
fn test_mle_corner_cases() {
for precision in [16, 17, 18] {
let mut hll = HyperLogLog::new(precision, 21).unwrap();
for i in 1..5000 {
let mut hasher = DefaultHasher::new();
i.hash(&mut hasher);
let hash = hasher.finish();
hll.add_hash(hash)
}

let cardinality = hll.cardinality();

assert!(cardinality > 4500 && cardinality < 5500);

// We build a second hll to check whether the union of the two
// hlls is consistent with the cardinality of the union.
let mut hll2 = HyperLogLog::new(precision, 21).unwrap();

for i in 5000..10000 {
let mut hasher = DefaultHasher::new();
i.hash(&mut hasher);
let hash = hasher.finish();
hll2.add_hash(hash)
}

let mut hll_union = hll.clone();
hll_union.merge(&hll2).unwrap();
let cardinality_union = hll_union.cardinality();

assert!(
cardinality_union > 9500 && cardinality_union < 10500,
"precision: {}, cardinality_union: {}",
precision,
cardinality_union
);

let intersection = hll.intersection(&hll2);

assert!(intersection < 500);
}
}
}

0 comments on commit 13b819b

Please sign in to comment.