diff --git a/src/core/src/index/sbt/mhbt.rs b/src/core/src/index/sbt/mhbt.rs index 2d4ceb3fb8..f7c310f8e4 100644 --- a/src/core/src/index/sbt/mhbt.rs +++ b/src/core/src/index/sbt/mhbt.rs @@ -331,6 +331,41 @@ mod test { Ok(()) } + #[test] + #[ignore] + fn find_one_or_any_sbt() -> Result<(), Box> { + let mut filename = PathBuf::from(env!("CARGO_MANIFEST_DIR")); + filename.push("../../tests/test-data/v5.sbt.json"); + + let sbt = MHBT::from_path(filename)?; + + let mut filename = PathBuf::from(env!("CARGO_MANIFEST_DIR")); + filename.push("../../tests/test-data/.sbt.v3/60f7e23c24a8d94791cc7a8680c493f9"); + + let mut reader = BufReader::new(File::open(filename)?); + let sigs = Signature::load_signatures( + &mut reader, + Some(31), + Some("DNA".try_into().unwrap()), + None, + )?; + let sig_data = sigs[0].clone(); + + let leaf: SigStore<_> = sig_data.into(); + + let find_results = sbt.find(search_minhashes, &leaf, 0.5)?; + assert_eq!(find_results.len(), 1); + let find_one_results = sbt.find_one(search_minhashes, &leaf, 0.5); + assert!(find_one_results.is_some()); + assert_eq!(find_results[0], find_one_results.unwrap()); + assert!(sbt.find_any(search_minhashes, &leaf, 0.5)); + + assert!(sbt.find_one(|_, _, _| false, &leaf, 0.9).is_none()); + assert!(!sbt.find_any(|_, _, _| false, &leaf, 0.9)); + + Ok(()) + } + #[test] fn scaffold_sbt() { let mut filename = PathBuf::from(env!("CARGO_MANIFEST_DIR")); diff --git a/src/core/src/index/sbt/mod.rs b/src/core/src/index/sbt/mod.rs index 51e4b98788..1d88bc4ccc 100644 --- a/src/core/src/index/sbt/mod.rs +++ b/src/core/src/index/sbt/mod.rs @@ -482,9 +482,7 @@ where sig: &'a L, threshold: f64, ) -> bool { - SBTFindIter::new(self, search_fn, sig, threshold) - .next() - .is_some() + self.find_one(search_fn, sig, threshold).is_some() } /// Finds an element in the SBT that satisfies the provided `search_fn` function for the given threshold. @@ -502,7 +500,7 @@ where sig: &'a L, threshold: f64, ) -> Option<&'a L> { - SBTFindIter::new(self, search_fn, sig, threshold).next() + self.find_iter(search_fn, sig, threshold).next() } }