From 900cabf466ee3c6d5cc686d01abc1779d95e928a Mon Sep 17 00:00:00 2001
From: Jakob Nybo Nissen
Date: Fri, 30 Aug 2024 14:28:31 +0200
Subject: [PATCH] Use pqdict for efficient reclustering deduplication

---
 pyproject.toml       |   1 +
 vamb/reclustering.py | 129 ++++++++++++++++++++++++++-----------------
 2 files changed, 78 insertions(+), 52 deletions(-)

diff --git a/pyproject.toml b/pyproject.toml
index 82983def..a39bbee9 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -12,6 +12,7 @@ dependencies = [
     "loguru == 0.7.2",
     "pyhmmer == 0.10.12",
     "pyrodigal == 3.4.1",
+    "pqdict == 1.4.0",
 ]
 # Currently pycoverm does not have binaries for Python > 3.12.
 # The dependency resolver, will not error on Python 3.13, but attempt
diff --git a/vamb/reclustering.py b/vamb/reclustering.py
index 45a3bdb4..5261e294 100644
--- a/vamb/reclustering.py
+++ b/vamb/reclustering.py
@@ -13,21 +13,25 @@
 from vamb.parsecontigs import CompositionMetaData
 from vamb.vambtools import RefHasher
 from collections.abc import Sequence, Iterable
-from typing import Callable, NewType, Union
-import heapq
+from typing import NewType, Optional, Union
+
+# TODO: We can get rid of this dep by using the trick from the heapq docs here:
+# https://docs.python.org/3/library/heapq.html#priority-queue-implementation-notes
+# However, this might be a little more tricky to implement
+from pqdict import pqdict
 
 # We use these aliases to be able to work with integers, which is faster.
 ContigId = NewType("ContigId", int)
 BinId = NewType("BinId", int)
 
+# TODO: We might want to benchmark the best value for this constant.
+# Right now, we do too much duplicated work by clustering 18 times.
 EPS_VALUES = np.arange(0.01, 0.35, 0.02)
 
 
-# This is a bottleneck, speed wise. Hence, we do a bunch of tricks to remove the
-# obviously bad bins to reduce computational cost.
 def deduplicate(
-    scoring: Callable[[Iterable[ContigId]], tuple[float, float]],
     bins: dict[BinId, set[ContigId]],
+    markers: Markers,
 ) -> list[set[ContigId]]:
     """
     deduplicate(f, bins)
@@ -38,18 +42,16 @@ def deduplicate(
     previously been returned from any other bin.
     Returns `list[tuple[float, set[ContigId]]]` with scored, disjoint bins.
     """
+    # This function is a bottleneck, performance wise. Hence, we begin by filtering away some clusters
+    # that are trivial.
     contig_sets = remove_duplicate_bins(bins.values())
-    scored_contig_sets = remove_badly_contaminated(scoring, contig_sets)
+    scored_contig_sets = remove_badly_contaminated(contig_sets, markers)
     (to_deduplicate, result) = remove_unambigous_bins(scored_contig_sets)
     bins = {BinId(i): b for (i, (b, _)) in enumerate(to_deduplicate)}
 
-    # Use a heap, such that `heappop` will return the best bin
-    # (Heaps in Python are min-heaps, so we use the negative score.)
-    heap = [
-        (-score_from_comp_cont(s), BinId(i))
-        for (i, (_, s)) in enumerate(to_deduplicate)
-    ]
-    heapq.heapify(heap)
+    # This is a mutable priority queue. It allows us to greedily take the best cluster,
+    # and then update all the clusters that share contigs with the removed best cluster.
+    queue = pqdict.maxpq(((BinId(i), s) for (i, (_, s)) in enumerate(to_deduplicate)))
 
     # When removing the best bin, the contigs in that bin must be removed from all
     # other bins, which causes some bins's score to be invalidated.
@@ -61,8 +63,8 @@ def deduplicate(
         for ci in c:
             bins_of_contig[ci].add(b)
 
-    while len(heap) > 0:
-        (_, best_bin) = heapq.heappop(heap)
+    while len(queue) > 0:
+        best_bin = queue.pop()
         contigs = bins[best_bin]
         result.append(contigs)
 
@@ -84,26 +86,26 @@ def deduplicate(
         # Remove the bin we picked as the best
         del bins[best_bin]
 
-        # Check this here to skip recomputing scores and re-heapifying, since that
-        # takes time.
-        # TODO: Could possibly skip this sometimes, if we know that any of the recomputed
-        # has a score lower than the next from the heap (and the heap top is not to be recomputed)
-        if len(to_recompute) > 0:
-            heap = [(s, b) for (s, b) in heap if b not in to_recompute]
-            for bin in to_recompute:
-                # We could potentially have added some bins in `to_recompute` which have had
-                # all their members removed.
-                # These empty bins should be discarded
-                c = bins.get(bin, None)
-                if c is not None:
-                    heap.append((-score_from_comp_cont(scoring(c)), bin))
-            heapq.heapify(heap)
+        for bin_to_recompute in to_recompute:
+            contigs_in_rec = bins.get(bin_to_recompute, None)
+            # If the recomputed bin is now empty, we delete it from queue
+            if contigs_in_rec is None:
+                queue.pop(bin_to_recompute)
+            else:
+                # We could use the saturating version of the function here,
+                # but it's unlikely to make a big difference performance wise,
+                # since the truly terrible bins have been filtered away
+                counts = count_markers(contigs_in_rec, markers)
+                new_score = score_from_comp_cont(get_completeness_contamination(counts))
+                queue.updateitem(bin_to_recompute, new_score)
 
     return result
 
 
 def remove_duplicate_bins(sets: Iterable[set[ContigId]]) -> list[set[ContigId]]:
     seen_sets: set[frozenset[ContigId]] = set()
+    # This is just for computational efficiency, so we don't instantiate frozen sets
+    # with single elements.
     seen_singletons: set[ContigId] = set()
     for contig_set in sets:
         if len(contig_set) == 1:
@@ -120,22 +122,29 @@ def remove_duplicate_bins(sets: Iterable[set[ContigId]]) -> list[set[ContigId]]:
 
 
 def remove_badly_contaminated(
-    scorer: Callable[[Iterable[ContigId]], tuple[float, float]],
     sets: Iterable[set[ContigId]],
-) -> list[tuple[set[ContigId], tuple[float, float]]]:
-    result: list[tuple[set[ContigId], tuple[float, float]]] = []
+    markers: Markers,
+) -> list[tuple[set[ContigId], float]]:
+    result: list[tuple[set[ContigId], float]] = []
     max_contamination = 1.0
     for contig_set in sets:
-        (completeness, contamination) = scorer(contig_set)
+        counts = count_markers_saturated(contig_set, markers)
+        # None here means that the contamination is already so high we stop counting.
+        # this is a shortcut for efficiency
+        if counts is None:
+            continue
+        (completeness, contamination) = get_completeness_contamination(counts)
         if contamination <= max_contamination:
-            result.append((contig_set, (completeness, contamination)))
+            result.append(
+                (contig_set, score_from_comp_cont((completeness, contamination)))
+            )
     return result
 
 
 def remove_unambigous_bins(
-    sets: list[tuple[set[ContigId], tuple[float, float]]],
-) -> tuple[list[tuple[set[ContigId], tuple[float, float]]], list[set[ContigId]]]:
-    """Remove all bins from d for which all the contigs are only present in that one bin,
+    sets: list[tuple[set[ContigId], float]],
+) -> tuple[list[tuple[set[ContigId], float]], list[set[ContigId]]]:
+    """Remove all bins for which all the contigs are only present in that one bin,
     and put them in the returned list.
     These contigs have a trivial, unambiguous assignment.
""" @@ -147,7 +156,7 @@ def remove_unambigous_bins( in_single_bin[contig] = True elif existing is True: in_single_bin[contig] = False - to_deduplicate: list[tuple[set[ContigId], tuple[float, float]]] = [] + to_deduplicate: list[tuple[set[ContigId], float]] = [] unambiguous: list[set[ContigId]] = [] for contig_set, scores in sets: if all(in_single_bin[c] for c in contig_set): @@ -296,8 +305,30 @@ def count_markers( for contig in contigs: m = markers.markers[contig] if m is not None: - for i in m: - counts[i] += 1 + counts[m] += 1 + return counts + + +# Same as above, but once we see a very high number of marker genes, +# we bail. This is because a large fraction of time spent in this module +# would otherwise be counting markers of huge clusters, long after we already +# know it's hopelessly contaminated +def count_markers_saturated( + contigs: Iterable[ContigId], + markers: Markers, +) -> Optional[np.ndarray]: + counts = np.zeros(markers.n_markers, dtype=np.int32) + # This implies contamination >= 2.0. The actual value depends on the unique + # number of markers, which is too slow to compute in this hot function + max_markers = 3 * markers.n_markers + n_markers = 0 + for contig in contigs: + m = markers.markers[contig] + if m is not None: + n_markers += len(m) + counts[m] += 1 + if n_markers > max_markers: + return None return counts @@ -358,9 +389,8 @@ def recluster_dbscan( ) -> list[set[ContigId]]: # Since DBScan is computationally expensive, and scales poorly with the number # of contigs, we use taxonomy to only cluster within each genus - indices_by_genus = group_indices_by_genus(taxonomy) result: list[set[ContigId]] = [] - for indices in indices_by_genus.values(): + for indices in group_indices_by_genus(taxonomy): genus_latent = latent[indices] genus_clusters = dbscan_genus( genus_latent, indices, contiglengths[indices], markers, num_processes @@ -409,21 +439,16 @@ def dbscan_genus( # present in one output bin, using the scoring function to greedily # output the best of the redundant bins. bin_dict = {BinId(i): c for (i, c) in enumerate(redundant_bins)} - return deduplicate( - lambda x: get_completeness_contamination(count_markers(x, markers)), bin_dict - ) + return deduplicate(bin_dict, markers) def group_indices_by_genus( taxonomy: Taxonomy, -) -> dict[str, np.ndarray]: +) -> list[np.ndarray]: if not taxonomy.is_canonical: raise ValueError("Can only group by genus for a canonical taxonomy") - by_genus: dict[str, list[ContigId]] = defaultdict(list) + by_genus: dict[Optional[str], list[ContigId]] = defaultdict(list) for i, tax in enumerate(taxonomy.contig_taxonomies): genus = None if tax is None else tax.genus - # TODO: Should we also cluster the ones with unknown genus? - # Currently, we just skip it here - if genus is not None: - by_genus[genus].append(ContigId(i)) - return {g: np.array(i, dtype=np.int32) for (g, i) in by_genus.items()} + by_genus[genus].append(ContigId(i)) + return [np.array(i, dtype=np.int32) for i in by_genus.values()]