From 4bae86bef16237ec5c84600b1f5e7a6f67cc35fd Mon Sep 17 00:00:00 2001
From: "C. Titus Brown"
Date: Tue, 15 Oct 2024 11:06:04 -0700
Subject: [PATCH] MRG: refactor `calculate_gather_stats` to disallow repeated
 downsampling (#3352)

This PR builds on the refactoring in #3342 to do less downsampling, and it
also avoids computing the same intersection twice (per #3196). Benchmarks in
https://github.com/sourmash-bio/sourmash_plugin_branchwater/pull/471 are
pretty astonishing... A sketch of the new calling contract follows the patch.

Fixes https://github.com/sourmash-bio/sourmash/issues/3196

---------

Co-authored-by: Luiz Irber
---
 src/core/src/index/mod.rs                    | 31 +++++++++---------
 src/core/src/index/revindex/disk_revindex.rs | 33 +++++++++++---------
 2 files changed, 36 insertions(+), 28 deletions(-)

diff --git a/src/core/src/index/mod.rs b/src/core/src/index/mod.rs
index c71bb5e58..8ed7f63e0 100644
--- a/src/core/src/index/mod.rs
+++ b/src/core/src/index/mod.rs
@@ -17,7 +17,6 @@ use getset::{CopyGetters, Getters, Setters};
 use log::trace;
 use serde::{Deserialize, Serialize};
 use stats::{median, stddev};
-use std::cmp::max;
 use typed_builder::TypedBuilder;
 
 use crate::ani_utils::{ani_ci_from_containment, ani_from_containment};
@@ -28,6 +27,7 @@ use crate::selection::Selection;
 use crate::signature::SigsTrait;
 use crate::sketch::minhash::KmerMinHash;
 use crate::storage::SigStore;
+use crate::Error::CannotUpsampleScaled;
 use crate::Result;
 
 #[derive(TypedBuilder, CopyGetters, Getters, Setters, Serialize, Deserialize, Debug, PartialEq)]
@@ -209,7 +209,7 @@ where
 #[allow(clippy::too_many_arguments)]
 pub fn calculate_gather_stats(
     orig_query: &KmerMinHash,
-    query: KmerMinHash,
+    remaining_query: KmerMinHash,
     match_sig: SigStore,
     match_size: usize,
     gather_result_rank: usize,
@@ -218,29 +218,31 @@ pub fn calculate_gather_stats(
     calc_abund_stats: bool,
     calc_ani_ci: bool,
     confidence: Option<f64>,
-) -> Result<GatherResult> {
+) -> Result<(GatherResult, (Vec<u64>, u64))> {
     // get match_mh
     let match_mh = match_sig.minhash().expect("cannot retrieve sketch");
 
-    let max_scaled = max(match_mh.scaled(), query.scaled());
-    let query = query
-        .downsample_scaled(max_scaled)
-        .expect("cannot downsample query");
+    // it's ok to downsample match, but query is often big and repeated,
+    // so we do not allow downsampling of query in this function.
+    if match_mh.scaled() > remaining_query.scaled() {
+        return Err(CannotUpsampleScaled);
+    }
+
     let match_mh = match_mh
         .clone()
-        .downsample_scaled(max_scaled)
+        .downsample_scaled(remaining_query.scaled())
         .expect("cannot downsample match");
 
     // calculate intersection
     let isect = match_mh
-        .intersection(&query)
+        .intersection(&remaining_query)
         .expect("could not do intersection");
     let isect_size = isect.0.len();
     trace!("isect_size: {}", isect_size);
-    trace!("query.size: {}", query.size());
+    trace!("query.size: {}", remaining_query.size());
 
     //bp remaining in subtracted query
-    let remaining_bp = (query.size() - isect_size) * query.scaled() as usize;
+    let remaining_bp = (remaining_query.size() - isect_size) * remaining_query.scaled() as usize;
 
     // stats for this match vs original query
     let (intersect_orig, _) = match_mh.intersection_size(orig_query).unwrap();
@@ -300,7 +302,7 @@ pub fn calculate_gather_stats(
     // If abundance, calculate abund-related metrics (vs current query)
     if calc_abund_stats {
         // take abunds from subtracted query
-        let (abunds, unique_weighted_found) = match match_mh.inflated_abundances(&query) {
+        let (abunds, unique_weighted_found) = match match_mh.inflated_abundances(&remaining_query) {
             Ok((abunds, unique_weighted_found)) => (abunds, unique_weighted_found),
             Err(e) => {
                 return Err(e);
@@ -347,7 +349,7 @@ pub fn calculate_gather_stats(
         .sum_weighted_found(sum_total_weighted_found)
         .total_weighted_hashes(total_weighted_hashes)
         .build();
-    Ok(result)
+    Ok((result, isect))
 }
 
 #[cfg(test)]
@@ -403,7 +405,7 @@ mod test_calculate_gather_stats {
         let gather_result_rank = 0;
         let calc_abund_stats = true;
         let calc_ani_ci = false;
-        let result = calculate_gather_stats(
+        let (result, _isect) = calculate_gather_stats(
             &orig_query,
             query,
             match_sig.into(),
@@ -416,6 +418,7 @@ mod test_calculate_gather_stats {
             None,
         )
         .unwrap();
+
         // first, print all results
         assert_eq!(result.filename(), "match-filename");
         assert_eq!(result.name(), "match-name");
diff --git a/src/core/src/index/revindex/disk_revindex.rs b/src/core/src/index/revindex/disk_revindex.rs
index b11ddb168..7386e9ebb 100644
--- a/src/core/src/index/revindex/disk_revindex.rs
+++ b/src/core/src/index/revindex/disk_revindex.rs
@@ -1,3 +1,4 @@
+use std::cmp::max;
 use std::hash::{BuildHasher, BuildHasherDefault, Hash, Hasher};
 use std::path::Path;
 use std::sync::atomic::{AtomicUsize, Ordering};
@@ -393,30 +394,28 @@ impl RevIndexOps for RevIndex {
             }
 
             let match_sig = self.collection.sig_for_dataset(dataset_id)?;
-
-            // get downsampled minhashes for comparison.
             let match_mh = match_sig.minhash().unwrap().clone();
-            let scaled = query.scaled();
+
+            // make downsampled minhashes
+            let max_scaled = max(match_mh.scaled(), query.scaled());
             let match_mh = match_mh
-                .downsample_scaled(scaled)
+                .downsample_scaled(max_scaled)
                 .expect("cannot downsample match");
 
+            // repeatedly downsample query, then extract to KmerMinHash
+            // => calculate_gather_stats
+            query = query
+                .downsample_scaled(max_scaled)
+                .expect("cannot downsample query");
+            let query_mh = KmerMinHash::from(query.clone());
+
             // just calculate essentials here
             let gather_result_rank = matches.len();
-            let query_mh = KmerMinHash::from(query.clone());
-            // grab the specific intersection:
-            let isect = match_mh
-                .intersection(&query_mh)
-                .expect("failed to intersect");
-            let mut isect_mh = match_mh.clone();
-            isect_mh.clear();
-            isect_mh.add_many(&isect.0)?;
-
             // Calculate stats
-            let gather_result = calculate_gather_stats(
+            let (gather_result, isect) = calculate_gather_stats(
                 &orig_query,
                 query_mh,
                 match_sig,
@@ -429,6 +428,12 @@ impl RevIndexOps for RevIndex {
                 ani_confidence_interval_fraction,
             )
             .expect("could not calculate gather stats");
+
+            // use intersection from calc_gather_stats to make a KmerMinHash.
+            let mut isect_mh = match_mh.clone();
+            isect_mh.clear();
+            isect_mh.add_many(&isect.0)?;
+
             // keep track of the sum weighted found
             sum_weighted_found = gather_result.sum_weighted_found();
             matches.push(gather_result);
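
Note: what follows is a minimal, self-contained sketch of the contract this
patch adopts in `calculate_gather_stats`: never downsample (and certainly
never upsample) the query inside the stats function, downsample only the
match, compute the intersection once, and return it to the caller. The names
`MinHash`, `GatherError`, and `gather_stats_sketch` are illustrative
stand-ins, not the sourmash API.

// Toy FracMinHash-style sketch; `hashes` is kept sorted and deduplicated.
#[derive(Clone, Debug)]
struct MinHash {
    scaled: u64,
    hashes: Vec<u64>,
}

#[derive(Debug, PartialEq)]
enum GatherError {
    CannotUpsampleScaled,
}

impl MinHash {
    // Downsampling to a coarser (larger) scaled keeps only hashes under the
    // new cutoff; this approximates FracMinHash semantics, where
    // max_hash ~= 2^64 / scaled. Going finer would need hashes never kept.
    fn downsample_scaled(&self, scaled: u64) -> Result<MinHash, GatherError> {
        if scaled < self.scaled {
            return Err(GatherError::CannotUpsampleScaled);
        }
        let max_hash = u64::MAX / scaled;
        Ok(MinHash {
            scaled,
            hashes: self
                .hashes
                .iter()
                .copied()
                .filter(|&h| h <= max_hash)
                .collect(),
        })
    }
}

// Callee-side pattern from the patch: refuse to downsample the (large,
// repeatedly passed) query, downsample only the match, compute the
// intersection once, and hand it back so the caller can reuse it.
fn gather_stats_sketch(
    remaining_query: &MinHash,
    match_mh: &MinHash,
) -> Result<(usize, Vec<u64>), GatherError> {
    if match_mh.scaled > remaining_query.scaled {
        return Err(GatherError::CannotUpsampleScaled);
    }
    let match_ds = match_mh.downsample_scaled(remaining_query.scaled)?;
    let isect: Vec<u64> = match_ds
        .hashes
        .iter()
        .copied()
        .filter(|h| remaining_query.hashes.binary_search(h).is_ok())
        .collect();
    Ok((isect.len(), isect))
}

fn main() -> Result<(), GatherError> {
    let query = MinHash { scaled: 1000, hashes: vec![10, 20, 30, 40] };
    let m = MinHash { scaled: 1000, hashes: vec![20, 40, 50] };
    let (n, isect) = gather_stats_sketch(&query, &m)?;
    assert_eq!((n, isect), (2, vec![20, 40]));

    // A match at a coarser scaled than the query is rejected instead of
    // silently downsampling the query.
    let coarse = MinHash { scaled: 2000, hashes: vec![20] };
    assert_eq!(
        gather_stats_sketch(&query, &coarse).unwrap_err(),
        GatherError::CannotUpsampleScaled
    );
    Ok(())
}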
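On the caller side in disk_revindex.rs, this inverts the old division of
labor: the potentially large query is downsampled at most once per gather
iteration, to max(match_mh.scaled(), query.scaled()), before
calculate_gather_stats is called, rather than being re-downsampled inside the
function for every candidate match; and the intersection that
calculate_gather_stats already computed is reused via clear() + add_many() to
build isect_mh, eliminating the second intersection() call flagged in #3196.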