Commit 900cabf

Use pqdict for efficient reclustering deduplication

jakobnissen committed Sep 4, 2024
1 parent dbeb097 commit 900cabf

Showing 2 changed files with 78 additions and 52 deletions.
1 change: 1 addition & 0 deletions pyproject.toml
@@ -12,6 +12,7 @@ dependencies = [
     "loguru == 0.7.2",
     "pyhmmer == 0.10.12",
     "pyrodigal == 3.4.1",
+    "pqdict == 1.4.0",
 ]
 # Currently pycoverm does not have binaries for Python > 3.12.
 # The dependency resolver will not error on Python 3.13, but attempt
129 changes: 77 additions & 52 deletions vamb/reclustering.py
@@ -13,21 +13,25 @@
 from vamb.parsecontigs import CompositionMetaData
 from vamb.vambtools import RefHasher
 from collections.abc import Sequence, Iterable
-from typing import Callable, NewType, Union
-import heapq
+from typing import NewType, Optional, Union
+
+# TODO: We can get rid of this dep by using the trick from the heapq docs here:
+# https://docs.python.org/3/library/heapq.html#priority-queue-implementation-notes
+# However, this might be a little more tricky to implement
+from pqdict import pqdict
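For reference, the heapq-docs trick the TODO above refers to looks roughly like the following. This is a minimal sketch, not code from this commit; the LazyMaxPQ name and its API are hypothetical. The idea is lazy deletion: entries are never removed from the heap eagerly, a side dict holds the authoritative priority per key, and stale heap entries are skipped on pop.

import heapq

class LazyMaxPQ:
    # Max-priority queue supporting reprioritization, built on heapq alone.
    def __init__(self) -> None:
        self._heap: list[tuple[float, int]] = []
        self._live: dict[int, float] = {}  # authoritative priority per key

    def push(self, key: int, priority: float) -> None:
        # Re-pushing a key leaves stale entries in the heap; they are
        # recognized and skipped later because self._live disagrees.
        self._live[key] = priority
        heapq.heappush(self._heap, (-priority, key))  # negate: heapq is a min-heap

    def remove(self, key: int) -> None:
        del self._live[key]  # heap entries for this key become stale

    def pop(self) -> int:
        while self._heap:
            (neg_priority, key) = heapq.heappop(self._heap)
            if self._live.get(key) == -neg_priority:  # entry still current?
                del self._live[key]
                return key
        raise KeyError("pop from an empty queue")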

 # We use these aliases to be able to work with integers, which is faster.
 ContigId = NewType("ContigId", int)
 BinId = NewType("BinId", int)

 # TODO: We might want to benchmark the best value for this constant.
 # Right now, we do too much duplicated work by clustering 18 times.
 EPS_VALUES = np.arange(0.01, 0.35, 0.02)
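EPS_VALUES drives one DBSCAN run per eps value, each producing its own partition of the same contigs; that is where the redundant, overlapping bins resolved by deduplicate below come from. A minimal sketch of such a sweep, assuming scikit-learn's DBSCAN and a hypothetical candidate_bins helper (this module's actual invocation may differ):

import numpy as np
from sklearn.cluster import DBSCAN

def candidate_bins(latent: np.ndarray) -> list[set[int]]:
    # One clustering per eps; bins from different eps values overlap heavily.
    bins: list[set[int]] = []
    for eps in np.arange(0.01, 0.35, 0.02):  # mirrors EPS_VALUES above
        labels = DBSCAN(eps=float(eps), min_samples=5).fit_predict(latent)
        for label in np.unique(labels):
            if label != -1:  # -1 is DBSCAN's noise label
                bins.append(set(np.flatnonzero(labels == label).tolist()))
    return bins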


+# This is a bottleneck, speed wise. Hence, we do a bunch of tricks to remove the
+# obviously bad bins to reduce computational cost.
 def deduplicate(
-    scoring: Callable[[Iterable[ContigId]], tuple[float, float]],
     bins: dict[BinId, set[ContigId]],
+    markers: Markers,
 ) -> list[set[ContigId]]:
     """
     deduplicate(f, bins)
@@ -38,18 +42,16 @@ def deduplicate(
     previously been returned from any other bin.
     Returns `list[tuple[float, set[ContigId]]]` with scored, disjoint bins.
     """
-    # This function is a bottleneck, performance wise. Hence, we begin by filtering away some clusters
-    # that are trivial.
     contig_sets = remove_duplicate_bins(bins.values())
-    scored_contig_sets = remove_badly_contaminated(scoring, contig_sets)
+    scored_contig_sets = remove_badly_contaminated(contig_sets, markers)
     (to_deduplicate, result) = remove_unambigous_bins(scored_contig_sets)
     bins = {BinId(i): b for (i, (b, _)) in enumerate(to_deduplicate)}

-    # Use a heap, such that `heappop` will return the best bin
-    # (Heaps in Python are min-heaps, so we use the negative score.)
-    heap = [
-        (-score_from_comp_cont(s), BinId(i))
-        for (i, (_, s)) in enumerate(to_deduplicate)
-    ]
-    heapq.heapify(heap)
+    # This is a mutable priority queue. It allows us to greedily take the best cluster,
+    # and then update all the clusters that share contigs with the removed best cluster.
+    queue = pqdict.maxpq(((BinId(i), s) for (i, (_, s)) in enumerate(to_deduplicate)))

     # When removing the best bin, the contigs in that bin must be removed from all
     # other bins, which causes some bins's score to be invalidated.
@@ -61,8 +63,8 @@
         for ci in c:
             bins_of_contig[ci].add(b)

-    while len(heap) > 0:
-        (_, best_bin) = heapq.heappop(heap)
+    while len(queue) > 0:
+        best_bin = queue.pop()
         contigs = bins[best_bin]
         result.append(contigs)

@@ -84,26 +86,26 @@
         # Remove the bin we picked as the best
         del bins[best_bin]

-        # Check this here to skip recomputing scores and re-heapifying, since that
-        # takes time.
-        # TODO: Could possibly skip this sometimes, if we know that any of the recomputed
-        # has a score lower than the next from the heap (and the heap top is not to be recomputed)
-        if len(to_recompute) > 0:
-            heap = [(s, b) for (s, b) in heap if b not in to_recompute]
-            for bin in to_recompute:
-                # We could potentially have added some bins in `to_recompute` which have had
-                # all their members removed.
-                # These empty bins should be discarded
-                c = bins.get(bin, None)
-                if c is not None:
-                    heap.append((-score_from_comp_cont(scoring(c)), bin))
-            heapq.heapify(heap)
+        for bin_to_recompute in to_recompute:
+            contigs_in_rec = bins.get(bin_to_recompute, None)
+            # If the recomputed bin is now empty, we delete it from queue
+            if contigs_in_rec is None:
+                queue.pop(bin_to_recompute)
+            else:
+                # We could use the saturating version of the function here,
+                # but it's unlikely to make a big difference performance wise,
+                # since the truly terrible bins have been filtered away
+                counts = count_markers(contigs_in_rec, markers)
+                new_score = score_from_comp_cont(get_completeness_contamination(counts))
+                queue.updateitem(bin_to_recompute, new_score)

     return result
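For reference, the pqdict operations the new code relies on, in one self-contained snippet (keys and priorities here are made up):

from pqdict import pqdict

queue = pqdict.maxpq([("bin_a", 0.40), ("bin_b", 0.95), ("bin_c", 0.60)])
assert queue.pop() == "bin_b"     # pop() returns the highest-priority key
queue.updateitem("bin_a", 0.99)   # reprioritize in place, no re-heapify
assert queue.pop() == "bin_a"
queue.pop("bin_c")                # pop(key) evicts a specific key
assert len(queue) == 0

This is the win over the heapq version: updateitem and keyed pop replace the filter-and-reheapify step, turning O(n) heap maintenance per popped bin into O(log n) work per affected bin.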


 def remove_duplicate_bins(sets: Iterable[set[ContigId]]) -> list[set[ContigId]]:
     seen_sets: set[frozenset[ContigId]] = set()
     # This is just for computational efficiency, so we don't instantiate frozen sets
     # with single elements.
     seen_singletons: set[ContigId] = set()
     for contig_set in sets:
         if len(contig_set) == 1:
@@ -120,22 +122,29 @@ def remove_duplicate_bins(sets: Iterable[set[ContigId]]) -> list[set[ContigId]]:


 def remove_badly_contaminated(
-    scorer: Callable[[Iterable[ContigId]], tuple[float, float]],
     sets: Iterable[set[ContigId]],
-) -> list[tuple[set[ContigId], tuple[float, float]]]:
-    result: list[tuple[set[ContigId], tuple[float, float]]] = []
+    markers: Markers,
+) -> list[tuple[set[ContigId], float]]:
+    result: list[tuple[set[ContigId], float]] = []
     max_contamination = 1.0
     for contig_set in sets:
-        (completeness, contamination) = scorer(contig_set)
+        counts = count_markers_saturated(contig_set, markers)
+        # None here means that the contamination is already so high we stop counting.
+        # this is a shortcut for efficiency
+        if counts is None:
+            continue
+        (completeness, contamination) = get_completeness_contamination(counts)
         if contamination <= max_contamination:
-            result.append((contig_set, (completeness, contamination)))
+            result.append(
+                (contig_set, score_from_comp_cont((completeness, contamination)))
+            )
     return result


 def remove_unambigous_bins(
-    sets: list[tuple[set[ContigId], tuple[float, float]]],
-) -> tuple[list[tuple[set[ContigId], tuple[float, float]]], list[set[ContigId]]]:
-    """Remove all bins from d for which all the contigs are only present in that one bin,
+    sets: list[tuple[set[ContigId], float]],
+) -> tuple[list[tuple[set[ContigId], float]], list[set[ContigId]]]:
+    """Remove all bins for which all the contigs are only present in that one bin,
     and put them in the returned list.
     These contigs have a trivial, unambiguous assignment.
     """
@@ -147,7 +156,7 @@ def remove_unambigous_bins(
                 in_single_bin[contig] = True
             elif existing is True:
                 in_single_bin[contig] = False
-    to_deduplicate: list[tuple[set[ContigId], tuple[float, float]]] = []
+    to_deduplicate: list[tuple[set[ContigId], float]] = []
     unambiguous: list[set[ContigId]] = []
     for contig_set, scores in sets:
         if all(in_single_bin[c] for c in contig_set):
@@ -296,8 +305,30 @@ def count_markers(
     for contig in contigs:
         m = markers.markers[contig]
         if m is not None:
-            for i in m:
-                counts[i] += 1
+            counts[m] += 1
     return counts
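One subtlety in the vectorized replacement: with repeated indices, NumPy's counts[m] += 1 increments each position at most once, while the removed loop counted every occurrence. The two are equivalent only if each contig's marker array holds distinct indices, which this change presumably assumes; np.add.at is the duplicate-safe spelling otherwise. A quick self-contained demonstration:

import numpy as np

counts = np.zeros(4, dtype=np.int32)
m = np.array([1, 1, 3])      # index 1 repeats
counts[m] += 1               # buffered fancy indexing: index 1 gains only 1
assert counts.tolist() == [0, 1, 0, 1]

counts = np.zeros(4, dtype=np.int32)
np.add.at(counts, m, 1)      # unbuffered: repeats accumulate
assert counts.tolist() == [0, 2, 0, 1]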


+# Same as above, but once we see a very high number of marker genes,
+# we bail. This is because a large fraction of time spent in this module
+# would otherwise be counting markers of huge clusters, long after we already
+# know it's hopelessly contaminated
+def count_markers_saturated(
+    contigs: Iterable[ContigId],
+    markers: Markers,
+) -> Optional[np.ndarray]:
+    counts = np.zeros(markers.n_markers, dtype=np.int32)
+    # This implies contamination >= 2.0. The actual value depends on the unique
+    # number of markers, which is too slow to compute in this hot function
+    max_markers = 3 * markers.n_markers
+    n_markers = 0
+    for contig in contigs:
+        m = markers.markers[contig]
+        if m is not None:
+            n_markers += len(m)
+            counts[m] += 1
+            if n_markers > max_markers:
+                return None
+    return counts
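To make the 3x bound concrete, here is the arithmetic, assuming CheckM-style definitions of completeness and contamination (an assumption for illustration; vamb's get_completeness_contamination may differ in detail):

# Hypothetical numbers, not from this codebase.
n_markers = 100               # distinct single-copy marker genes
total_hits = 3 * n_markers    # the point at which the function bails
unique_hits = n_markers       # best case: every marker seen at least once
completeness = unique_hits / n_markers                   # at most 1.0
contamination = (total_hits - unique_hits) / n_markers   # (300 - 100) / 100
assert contamination >= 2.0   # hence "implies contamination >= 2.0"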


@@ -358,9 +389,8 @@ def recluster_dbscan(
 ) -> list[set[ContigId]]:
     # Since DBScan is computationally expensive, and scales poorly with the number
     # of contigs, we use taxonomy to only cluster within each genus
-    indices_by_genus = group_indices_by_genus(taxonomy)
     result: list[set[ContigId]] = []
-    for indices in indices_by_genus.values():
+    for indices in group_indices_by_genus(taxonomy):
         genus_latent = latent[indices]
         genus_clusters = dbscan_genus(
             genus_latent, indices, contiglengths[indices], markers, num_processes
@@ -409,21 +439,16 @@ def dbscan_genus(
     # present in one output bin, using the scoring function to greedily
     # output the best of the redundant bins.
     bin_dict = {BinId(i): c for (i, c) in enumerate(redundant_bins)}
-    return deduplicate(
-        lambda x: get_completeness_contamination(count_markers(x, markers)), bin_dict
-    )
+    return deduplicate(bin_dict, markers)


 def group_indices_by_genus(
     taxonomy: Taxonomy,
-) -> dict[str, np.ndarray]:
+) -> list[np.ndarray]:
     if not taxonomy.is_canonical:
         raise ValueError("Can only group by genus for a canonical taxonomy")
-    by_genus: dict[str, list[ContigId]] = defaultdict(list)
+    by_genus: dict[Optional[str], list[ContigId]] = defaultdict(list)
     for i, tax in enumerate(taxonomy.contig_taxonomies):
         genus = None if tax is None else tax.genus
-        # TODO: Should we also cluster the ones with unknown genus?
-        # Currently, we just skip it here
-        if genus is not None:
-            by_genus[genus].append(ContigId(i))
-    return {g: np.array(i, dtype=np.int32) for (g, i) in by_genus.items()}
+        by_genus[genus].append(ContigId(i))
+    return [np.array(i, dtype=np.int32) for i in by_genus.values()]
