From 15638ffa8f305618a8336e596d12e7b254ddff7c Mon Sep 17 00:00:00 2001
From: Jakob Nybo Nissen
Date: Tue, 7 Nov 2023 09:16:59 +0100
Subject: [PATCH] Refactor binsplitting

This is a large change that overhauls how binsplitting is done and, as a
consequence, reworks some of the overall workflow in `__main__.py`.

The PR is intended to address the following problems:

* Before, we only output either the binsplit clusters or the unsplit clusters.
  This is problematic: we know the binsplit clusters are the best ones, so we
  would like to output these. However, the unsplit ones contain important
  information about the source cluster, which power users need to be able to
  recover.
  - Now, we output both `_split.tsv` and `_unsplit.tsv` files, if binsplitting
    takes place.
* Before, we defaulted to no binsplitting, even though we knew it was inferior.
  - Now, `-o C` is the default.
* Before, if a user passed in a wrong binsplit separator, Vamb would not error
  until the clustering step, and the error message would be inscrutable.
  - Now, the error is raised already when parsing the contigs, EXCEPT if the
    binsplit separator has defaulted to 'C', in which case binsplitting is
    disabled and the user is warned.
  - The error message is significantly improved and more explanatory.
* Before, the logic of where binsplitting happened was ad hoc and scattered all
  over the place. For example, binsplitting took place during cluster writing,
  during bin writing, during benchmarking, during clustering itself, and
  immediately after clustering. It was also implemented in multiple places.
  - Now, a `BinSplitter` class is responsible for binsplitting. The writer and
    loader functions do not binsplit.
  - Now, binsplitting mostly takes place immediately before writing the split
    clusters, meaning the clusters are unambiguously unsplit for the majority
    of the program.
---
 .github/workflows/cli_vamb.yml       |   2 +-
 test/test_parsecontigs.py            |  32 +-
 test/test_results.py                 |   5 +-
 test/test_vambtools.py               | 127 +++----
 vamb/__main__.py                     | 485 +++++++++++----------------
 vamb/parsecontigs.py                 |  22 --
 vamb/vambtools.py                    | 338 ++++++++++---------
 workflow_avamb/avamb.snake.conda.smk |  12 +-
 8 files changed, 456 insertions(+), 567 deletions(-)

diff --git a/.github/workflows/cli_vamb.yml b/.github/workflows/cli_vamb.yml
index 6b86d6d9..686fb90a 100644
--- a/.github/workflows/cli_vamb.yml
+++ b/.github/workflows/cli_vamb.yml
@@ -58,6 +58,6 @@ jobs:
         cat outdir_taxometer/log.txt
     - name: Run k-means reclustering
       run: |
-        vamb recluster --outdir outdir_recluster --fasta catalogue_mock.fna.gz --rpkm abundance_mock.npz --latent_path outdir_taxvamb/vaevae_latent.npy --clusters_path outdir_taxvamb/vaevae_clusters.tsv --hmmout_path markers_mock.hmmout --algorithm kmeans --minfasta 200000
+        vamb recluster --outdir outdir_recluster --fasta catalogue_mock.fna.gz --rpkm abundance_mock.npz --latent_path outdir_taxvamb/vaevae_latent.npy --clusters_path outdir_taxvamb/vaevae_clusters_split.tsv --hmmout_path markers_mock.hmmout --algorithm kmeans --minfasta 200000
         ls -la outdir_recluster
         cat outdir_recluster/log.txt

diff --git a/test/test_parsecontigs.py b/test/test_parsecontigs.py
index 68a6bd38..753f3f80 100644
--- a/test/test_parsecontigs.py
+++ b/test/test_parsecontigs.py
@@ -2,7 +2,6 @@
 import unittest
 import random
 import numpy as np
-import warnings

 import testtools
 from vamb.parsecontigs import Composition, CompositionMetaData
@@ -41,26 +40,9 @@ def test_unique_names(self):
             1000,
         )

-    # Does not warn
-    def test_nowarn(self):
-        with warnings.catch_warnings():
warnings.simplefilter("error", UserWarning) - Composition.from_file(self.large_io, minlength=250) - - def test_warns_n_contigs(self): - with self.assertWarns(UserWarning): - Composition.from_file(self.io, minlength=250) - - def test_warns_minlength(self): - with self.assertWarns(UserWarning): - Composition.from_file(self.large_io, minlength=275) - def test_filter_minlength(self): minlen = 500 - - with self.assertWarns(UserWarning): - composition = Composition.from_file(self.io, minlength=450) - + composition = Composition.from_file(self.io, minlength=450) md = composition.metadata hash1 = md.refhash @@ -99,8 +81,7 @@ def test_minlength(self): Composition.from_file(self.io, minlength=3) def test_properties(self): - with self.assertWarns(UserWarning): - composition = Composition.from_file(self.io, minlength=420) + composition = Composition.from_file(self.io, minlength=420) passed = list(filter(lambda x: len(x.sequence) >= 420, self.records)) self.assertEqual(composition.nseqs, len(composition.metadata.identifiers)) @@ -122,8 +103,7 @@ def test_properties(self): def test_save_load(self): buf = io.BytesIO() - with self.assertWarns(UserWarning): - composition_1 = Composition.from_file(self.io) + composition_1 = Composition.from_file(self.io) md1 = composition_1.metadata composition_1.save(buf) buf.seek(0) @@ -153,10 +133,8 @@ def test_windows_newlines(self): buf1.seek(0) buf2.seek(0) - with self.assertWarns(UserWarning): - comp1 = Composition.from_file(buf1) - with self.assertWarns(UserWarning): - comp2 = Composition.from_file(buf2) + comp1 = Composition.from_file(buf1) + comp2 = Composition.from_file(buf2) self.assertEqual(comp1.metadata.refhash, comp2.metadata.refhash) self.assertTrue(np.all(comp1.matrix == comp2.matrix)) diff --git a/test/test_results.py b/test/test_results.py index 59fce4ea..d38ccf55 100644 --- a/test/test_results.py +++ b/test/test_results.py @@ -30,9 +30,8 @@ def setUp(self): self.io.seek(0) def test_runs(self): - with self.assertRaises(UserWarning): - comp = vamb.parsecontigs.Composition.from_file(self.io) - self.assertIsInstance(comp, vamb.parsecontigs.Composition) + comp = vamb.parsecontigs.Composition.from_file(self.io) + self.assertIsInstance(comp, vamb.parsecontigs.Composition) if TEST_UNSTABLE_HASHES: diff --git a/test/test_vambtools.py b/test/test_vambtools.py index 837ee36b..fcd5fb84 100644 --- a/test/test_vambtools.py +++ b/test/test_vambtools.py @@ -10,8 +10,11 @@ import numpy as np import string import torch +import pathlib +import shutil import vamb +from vamb.vambtools import BinSplitter import testtools PARENTDIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) @@ -220,7 +223,7 @@ def test_bad_files(self): # String input with self.assertRaises(TypeError): data = ">abc\nTAG\na\nAC".splitlines() - list(vamb.vambtools.byte_iterfasta(data)) + list(vamb.vambtools.byte_iterfasta(data)) # type:ignore # Various correct formats def test_good_files(self): @@ -379,16 +382,48 @@ class TestBinSplit(unittest.TestCase): ("s12-bin2", {"s12-c0"}), ] + def test_inert(self): + self.assertEqual( + BinSplitter("").splitter, BinSplitter.inert_splitter().splitter + ) + def test_split(self): - self.assertEqual(list(vamb.vambtools.binsplit(self.before, "-")), self.after) + self.assertEqual(list(BinSplitter("-").binsplit(self.before)), self.after) def test_badsep(self): with self.assertRaises(KeyError): - list(vamb.vambtools.binsplit(self.before, "2")) + list(BinSplitter("2").binsplit(self.before)) def test_badtype(self): - with self.assertRaises(TypeError): - 
list(vamb.vambtools.binsplit([(1, [2])], "")) + with self.assertRaises(Exception): + list(BinSplitter("x").binsplit([(1, [2])])) # type:ignore + + def test_nosplit(self): + self.assertEqual( + list(BinSplitter("").binsplit(self.before)), + [(k, set(s)) for (k, s) in self.before], + ) + + def test_initialize(self): + # Nothing happens to an inert splitter + b = BinSplitter.inert_splitter() + s = b.splitter + b.initialize([""]) + self.assertEqual(s, b.splitter) + + b = BinSplitter("X") + with self.assertRaises(ValueError): + b.initialize(["AXC", "S1C2"]) + + b = BinSplitter(None) + with self.assertWarns(UserWarning): + b.initialize(["S1C2", "ABC"]) + + b = BinSplitter(None) + b.initialize(["S1C2", "KMCPLK"]) + + b = BinSplitter("XYZ") + b.initialize(["ABXYZCD", "KLMXYZA"]) class TestConcatenateFasta(unittest.TestCase): @@ -459,8 +494,8 @@ def setUp(self): self.io.truncate(0) self.io.seek(0) - def linesof(self, str): - return list(filter(lambda x: not x.startswith("#"), str.splitlines())) + def linesof(self, string: str): + return list(filter(lambda x: not x.startswith("#"), string.splitlines())) def conforms(self, str, clusters): lines = self.linesof(str) @@ -488,46 +523,29 @@ def conforms(self, str, clusters): self.assertEqual(read_names, printed_names) def test_not_writable(self): - buf = io.BufferedReader(io.BytesIO(b"")) + buf = io.BufferedReader(io.BytesIO(b"")) # type:ignore with self.assertRaises(ValueError): - vamb.vambtools.write_clusters(buf, self.test_clusters) - - def test_invalid_max_clusters(self): - with self.assertRaises(ValueError): - vamb.vambtools.write_clusters(self.io, self.test_clusters, max_clusters=0) - - with self.assertRaises(ValueError): - vamb.vambtools.write_clusters(self.io, self.test_clusters, max_clusters=-11) - - def test_header_has_newline(self): - with self.assertRaises(ValueError): - vamb.vambtools.write_clusters(self.io, self.test_clusters, header="foo\n") + vamb.vambtools.write_clusters(buf, self.test_clusters) # type:ignore def test_normal(self): - vamb.vambtools.write_clusters(self.io, self.test_clusters, header="someheader") - self.assertTrue(self.io.getvalue().startswith("# someheader")) + vamb.vambtools.write_clusters(self.io, self.test_clusters) self.conforms(self.io.getvalue(), self.test_clusters) def test_max_clusters(self): - vamb.vambtools.write_clusters(self.io, self.test_clusters, max_clusters=2) + vamb.vambtools.write_clusters(self.io, self.test_clusters[:2]) lines = self.linesof(self.io.getvalue()) self.assertEqual(len(lines), 9) self.conforms(self.io.getvalue(), self.test_clusters[:2]) - def test_min_size(self): - vamb.vambtools.write_clusters(self.io, self.test_clusters, min_size=5) - lines = self.linesof(self.io.getvalue()) - self.assertEqual(len(lines), 6) - self.conforms(self.io.getvalue(), self.test_clusters[1:2]) - class TestWriteBins(unittest.TestCase): file = io.BytesIO() N_BINS = 10 - minsize = 5 * 175 # mean of bin size - dirname = os.path.join( - tempfile.gettempdir(), - "".join(random.choices(string.ascii_letters + string.digits, k=10)), + dir = pathlib.Path( + os.path.join( + tempfile.gettempdir(), + "".join(random.choices(string.ascii_letters + string.digits, k=10)), + ) ) @classmethod @@ -552,7 +570,7 @@ def setUp(self): def tearDown(self): try: - os.rmdir(self.dirname) + shutil.rmtree(self.dir) except FileNotFoundError: pass @@ -560,26 +578,26 @@ def test_bad_params(self): # Too many bins for maxbins with self.assertRaises(ValueError): vamb.vambtools.write_bins( - self.dirname, self.bins, self.file, maxbins=self.N_BINS - 1 
- ) - - # Negative minsize - with self.assertRaises(ValueError): - vamb.vambtools.write_bins( - self.dirname, self.bins, self.file, maxbins=self.N_BINS + 1, minsize=-1 + self.dir, self.bins, self.file, maxbins=self.N_BINS - 1 ) # Parent does not exist with self.assertRaises(NotADirectoryError): vamb.vambtools.write_bins( - "svogew/foo", self.bins, self.file, maxbins=self.N_BINS + 1 + pathlib.Path("svogew/foo"), + self.bins, + self.file, + maxbins=self.N_BINS + 1, ) - # Target file already exists + # Target is an existing file with self.assertRaises(FileExistsError): with tempfile.NamedTemporaryFile() as file: vamb.vambtools.write_bins( - file.name, self.bins, self.file, maxbins=self.N_BINS + 1 + pathlib.Path(file.name), + self.bins, + self.file, + maxbins=self.N_BINS + 1, ) # One contig missing from fasta dict @@ -587,36 +605,31 @@ def test_bad_params(self): bins = {k: v.copy() for k, v in self.bins.items()} next(iter(bins.values())).add("a_new_bin_which_does_not_exist") vamb.vambtools.write_bins( - self.dirname, bins, self.file, maxbins=self.N_BINS + 1 + self.dir, bins, self.file, maxbins=self.N_BINS + 1 ) def test_round_trip(self): with tempfile.TemporaryDirectory() as dir: vamb.vambtools.write_bins( - dir, self.bins, self.file, maxbins=self.N_BINS, minsize=self.minsize + pathlib.Path(dir), + self.bins, + self.file, + maxbins=self.N_BINS, ) reconstructed_bins: dict[str, set[str]] = dict() for filename in os.listdir(dir): with open(os.path.join(dir, filename), "rb") as file: entries = list(vamb.vambtools.byte_iterfasta(file)) - if sum(map(len, entries)) < self.minsize: - continue binname = filename[:-4] reconstructed_bins[binname] = set() for entry in entries: reconstructed_bins[binname].add(entry.identifier) - filtered_bins = { - k: v - for (k, v) in self.bins.items() - if sum(len(self.seqs[vi]) for vi in v) >= self.minsize - } - # Same bins - self.assertEqual(len(filtered_bins), len(reconstructed_bins)) + self.assertEqual(len(self.bins), len(reconstructed_bins)) self.assertEqual( - sum(map(len, filtered_bins.values())), + sum(map(len, self.bins.values())), sum(map(len, reconstructed_bins.values())), ) - self.assertEqual(filtered_bins, reconstructed_bins) + self.assertEqual(self.bins, reconstructed_bins) diff --git a/vamb/__main__.py b/vamb/__main__.py index aa28b099..6fc8bbe1 100755 --- a/vamb/__main__.py +++ b/vamb/__main__.py @@ -12,6 +12,7 @@ import time import random import pycoverm +import itertools from math import isfinite from typing import Optional, IO, Tuple, Union from pathlib import Path @@ -53,14 +54,19 @@ class CompositionPath(type(Path())): class CompositionOptions: - __slots__ = ["path", "min_contig_length"] + __slots__ = ["path", "min_contig_length", "warn_on_few_seqs"] def __init__( - self, fastapath: Optional[Path], npzpath: Optional[Path], min_contig_length: int + self, + fastapath: Optional[Path], + npzpath: Optional[Path], + min_contig_length: int, + warn_on_few_seqs: bool, ): assert isinstance(fastapath, (Path, type(None))) assert isinstance(npzpath, (Path, type(None))) assert isinstance(min_contig_length, int) + assert isinstance(warn_on_few_seqs, bool) if min_contig_length < 250: raise argparse.ArgumentTypeError( @@ -82,6 +88,7 @@ def __init__( assert npzpath is not None self.path = CompositionPath(npzpath) self.min_contig_length = min_contig_length + self.warn_on_few_seqs = warn_on_few_seqs class AbundancePath(type(Path())): @@ -247,7 +254,7 @@ class ReclusteringOptions: "latent_path", "clusters_path", "hmmout_path", - "binsplit_separator", + "binsplitter", 
"algorithm", ] @@ -272,9 +279,9 @@ def __init__( self.latent_path = latent_path self.clusters_path = clusters_path - self.binsplit_separator = binsplit_separator self.hmmout_path = hmmout_path self.algorithm = algorithm + self.binsplitter = vamb.vambtools.BinSplitter(binsplit_separator) class EncoderOptions: @@ -411,22 +418,19 @@ class ClusterOptions: __slots__ = [ "window_size", "min_successes", - "min_cluster_size", "max_clusters", - "binsplit_separator", + "binsplitter", ] def __init__( self, window_size: int, min_successes: int, - min_cluster_size: int, max_clusters: Optional[int], binsplit_separator: Optional[str], ): assert isinstance(window_size, int) assert isinstance(min_successes, int) - assert isinstance(min_cluster_size, int) assert isinstance(max_clusters, (int, type(None))) assert isinstance(binsplit_separator, (str, type(None))) @@ -434,10 +438,6 @@ def __init__( raise argparse.ArgumentTypeError("Window size must be at least 1") self.window_size = window_size - if min_cluster_size < 1: - raise argparse.ArgumentTypeError("Minimum cluster size must be at least 0") - self.min_cluster_size = min_cluster_size - if min_successes < 1 or min_successes > window_size: raise argparse.ArgumentTypeError( "Minimum cluster size must be in 1:windowsize" @@ -447,12 +447,7 @@ def __init__( if max_clusters is not None and max_clusters < 1: raise argparse.ArgumentTypeError("Max clusters must be at least 1") self.max_clusters = max_clusters - - if binsplit_separator is not None and len(binsplit_separator) == 0: - raise argparse.ArgumentTypeError( - "Binsplit separator cannot be an empty string" - ) - self.binsplit_separator = binsplit_separator + self.binsplitter = vamb.vambtools.BinSplitter(binsplit_separator) class VambOptions: @@ -523,6 +518,7 @@ def log(string: str, logfile: IO[str], indent: int = 0): def calc_tnf( options: CompositionOptions, outdir: Path, + binsplitter: vamb.vambtools.BinSplitter, logfile: IO[str], ) -> vamb.parsecontigs.Composition: begintime = time.time() @@ -544,6 +540,30 @@ def calc_tnf( ) composition.save(outdir.joinpath("composition.npz")) + binsplitter.initialize(composition.metadata.identifiers) + + if options.warn_on_few_seqs and composition.nseqs < 20_000: + message = ( + f"WARNING: Kept only {composition.nseqs} sequences from FASTA file. " + "We normally expect 20,000 sequences or more to prevent overfitting. " + "As a deep learning model, VAEs are prone to overfitting with too few sequences. " + "You may want to bin more samples as a time, lower the beta parameter, " + "or use a different binner altogether." + ) + vamb.vambtools.log_and_warn(message, logfile=logfile) + + # Warn the user if any contigs have been observed, which is smaller + # than the threshold. + if not np.all(composition.metadata.mask): + n_removed = len(composition.metadata.mask) - np.sum(composition.metadata.mask) + message = ( + f"WARNING: The minimum sequence length has been set to {options.min_contig_length}, " + f"but {n_removed} sequences fell below this threshold and was filtered away." + "\nBetter results are obtained if the sequence file is filtered to the minimum " + "sequence length before mapping." + ) + vamb.vambtools.log_and_warn(message, logfile=logfile) + elapsed = round(time.time() - begintime, 2) print("", file=logfile) log( @@ -712,18 +732,19 @@ def trainaae( return latent, clusters_y_dict -def cluster( +def cluster_and_write_files( + vamb_options: VambOptions, cluster_options: ClusterOptions, - clusterspath: Path, + base_clusters_name: str, # e.g. 
+    bins_dir: Path,
+    fasta_catalogue: Path,
     latent: np.ndarray,
-    contignames: Sequence[str],  # of dtype object
-    lengths: Sequence[int],  # of dtype object
-    vamb_options: VambOptions,
+    sequence_names: Sequence[str],
+    sequence_lens: Sequence[int],
     logfile: IO[str],
-    cluster_prefix: str,
-) -> None:
+):
     begintime = time.time()
-
+    # Create cluster iterator
     log("\nClustering", logfile)
     log(f"Windowsize: {cluster_options.window_size}", logfile, 1)
     log(
@@ -732,127 +753,116 @@
         1,
     )
     log(f"Max clusters: {cluster_options.max_clusters}", logfile, 1)
-    log(f"Min cluster size: {cluster_options.min_cluster_size}", logfile, 1)
     log(f"Use CUDA for clustering: {vamb_options.cuda}", logfile, 1)
-    log(
-        "Separator: {}".format(
-            None
-            if cluster_options.binsplit_separator is None
-            else ('"' + cluster_options.binsplit_separator + '"')
-        ),
-        logfile,
-        1,
-    )
+    log(f"Binsplitter: {cluster_options.binsplitter.log_string()}", logfile, 1)

     cluster_generator = vamb.cluster.ClusterGenerator(
         latent,
-        lengths,  # type:ignore
+        sequence_lens,  # type:ignore
         windowsize=cluster_options.window_size,
         minsuccesses=cluster_options.min_successes,
         destroy=True,
         normalized=False,
-        # cuda=vamb_options.cuda,
-        cuda=False,  # disabled until clustering is fixed
+        cuda=vamb_options.cuda,
         rng_seed=vamb_options.seed,
     )

     renamed = (
-        (str(cluster_index + 1), {contignames[i] for i in members})
+        (str(cluster_index + 1), {sequence_names[i] for i in members})
         for (cluster_index, (_, members)) in enumerate(
             map(lambda x: x.as_tuple(), cluster_generator)
         )
     )

-    # Binsplit if given a separator
-    if cluster_options.binsplit_separator is not None:
-        maybe_split = vamb.vambtools.binsplit(
-            renamed, cluster_options.binsplit_separator
-        )
-    else:
-        maybe_split = renamed
-
-    with open(clusterspath, "w") as clustersfile:
-        clusternumber, ncontigs = vamb.vambtools.write_clusters(
-            clustersfile,
-            maybe_split,
-            max_clusters=cluster_options.max_clusters,
-            min_size=cluster_options.min_cluster_size,
-            rename=False,
-            cluster_prefix=cluster_prefix,
-        )
-
-    print("", file=logfile)
-    log(f"Clustered {ncontigs} contigs in {clusternumber} bins", logfile, 1)
-
+    # This also works correctly when max_clusters is None
+    first_clusters = itertools.islice(renamed, cluster_options.max_clusters)
+    unsplit_clusters = dict(first_clusters)
     elapsed = round(time.time() - begintime, 2)
     log(f"Clustered contigs in {elapsed} seconds.", logfile, 1)

+    write_clusters_and_bins(
+        vamb_options,
+        cluster_options.binsplitter,
+        base_clusters_name,
+        bins_dir,
+        fasta_catalogue,
+        unsplit_clusters,
+        sequence_names,
+        sequence_lens,
+        logfile,
+    )


-def write_fasta(
-    outdir: Path,
-    clusterspath: Path,
-    fastapath: Path,
-    contignames: Sequence[str],  # of object
-    contiglengths: np.ndarray,
-    minfasta: int,
+def write_clusters_and_bins(
+    vamb_options: VambOptions,
+    binsplitter: vamb.vambtools.BinSplitter,
+    base_clusters_name: str,  # e.g.
/foo/bar/vae -> /foo/bar/vae_unsplit.tsv + bins_dir: Path, + fasta_catalogue: Path, + clusters: dict[str, set[str]], + sequence_names: Sequence[str], + sequence_lens: Sequence[int], logfile: IO[str], - separator: Optional[str], -) -> None: +): + # Write unsplit clusters to file begintime = time.time() - - log("\nWriting FASTA files", logfile) - log("Minimum FASTA size: " + str(minfasta), logfile, 1) - assert len(contignames) == len(contiglengths) - - lengthof = dict(zip(contignames, contiglengths)) - filtered_clusters: dict[str, set[str]] = dict() - - with open(clusterspath) as file: - clusters = vamb.vambtools.read_clusters(file) - - for cluster, contigs in clusters.items(): - size = sum(lengthof[contig] for contig in contigs) - if size >= minfasta: - filtered_clusters[cluster] = clusters[cluster] - - del lengthof, clusters - keep: set[str] = set() - for contigs in filtered_clusters.values(): - keep.update(set(contigs)) - - with vamb.vambtools.Reader(fastapath) as file: - vamb.vambtools.write_bins( - outdir.joinpath("bins"), - filtered_clusters, - file, - maxbins=None, - separator=separator, + unsplit_path = Path(base_clusters_name + "_unsplit.tsv") + with open(unsplit_path, "w") as file: + (n_unsplit_clusters, n_contigs) = vamb.vambtools.write_clusters( + file, clusters.items() ) - ncontigs = sum(map(len, filtered_clusters.values())) - nfiles = len(filtered_clusters) - print("", file=logfile) - log(f"Wrote {ncontigs} contigs to {nfiles} FASTA files", logfile, 1) + # Open unsplit clusters and split them + if binsplitter.splitter is not None: + split_path = Path(base_clusters_name + "_split.tsv") + clusters = dict(binsplitter.binsplit(clusters.items())) + with open(split_path, "w") as file: + (n_split_clusters, _) = vamb.vambtools.write_clusters( + file, clusters.items() + ) + msg = f"Clustered {n_contigs} contigs in {n_split_clusters} split bins ({n_unsplit_clusters} clusters)" + else: + msg = f"Clustered {n_contigs} contigs in {n_unsplit_clusters} unsplit bins" + log("\n" + msg, logfile, 1) elapsed = round(time.time() - begintime, 2) - log(f"Wrote FASTA in {elapsed} seconds.", logfile, 1) + log(f"Wrote clusters file(s) in {elapsed} seconds.", logfile, 1) + + # Write bins, if necessary + if vamb_options.min_fasta_output_size is not None: + starttime = time.time() + filtered_clusters: dict[str, set[str]] = dict() + assert len(sequence_lens) == len(sequence_names) + sizeof = dict(zip(sequence_names, sequence_lens)) + for binname, contigs in clusters.items(): + if sum(sizeof[c] for c in contigs) >= vamb_options.min_fasta_output_size: + filtered_clusters[binname] = contigs + + with vamb.vambtools.Reader(fasta_catalogue) as file: + vamb.vambtools.write_bins( + bins_dir, + filtered_clusters, + file, + None, + ) + elapsed = round(time.time() - starttime, 2) + n_bins = len(filtered_clusters) + n_contigs = sum(len(v) for v in filtered_clusters.values()) + log( + f"\nWrote {n_bins} bins with {n_contigs} sequences in {elapsed} seconds.", + logfile, + 1, + ) def load_composition_and_abundance( vamb_options: VambOptions, comp_options: CompositionOptions, abundance_options: AbundanceOptions, + binsplitter: vamb.vambtools.BinSplitter, logfile: IO[str], ) -> Tuple[vamb.parsecontigs.Composition, vamb.parsebam.Abundance]: - log("Starting Vamb version " + ".".join(map(str, vamb.__version__)), logfile) - log("Date and time is " + str(datetime.datetime.now()), logfile, 1) - log("Random seed is " + str(vamb_options.seed), logfile, 1) - begintime = time.time() - - # Get TNFs, save as npz - composition = 
calc_tnf(comp_options, vamb_options.out_dir, logfile) - - # Parse BAMs, save as npz + composition = calc_tnf(comp_options, vamb_options.out_dir, binsplitter, logfile) abundance = calc_rpkm( abundance_options, vamb_options.out_dir, @@ -860,12 +870,6 @@ def load_composition_and_abundance( vamb_options.n_threads, logfile, ) - time_generating_input = round(time.time() - begintime, 2) - log( - f"\nTNF and coabundances generated in {time_generating_input} seconds.", - logfile, - 1, - ) return (composition, abundance) @@ -883,12 +887,11 @@ def run( vae_training_options = training_options.vae_options aae_training_options = training_options.aae_options - begintime = time.time() - composition, abundance = load_composition_and_abundance( vamb_options=vamb_options, comp_options=comp_options, abundance_options=abundance_options, + binsplitter=cluster_options.binsplitter, logfile=logfile, ) data_loader = vamb.encode.make_dataloader( @@ -947,119 +950,48 @@ def run( if vae_options is not None: assert latent is not None assert comp_metadata.nseqs == len(latent) - # Cluster, save tsv file - clusterspath = vamb_options.out_dir.joinpath("vae_clusters.tsv") - cluster( + cluster_and_write_files( + vamb_options, cluster_options, - clusterspath, + str(vamb_options.out_dir.joinpath("vae_clusters")), + vamb_options.out_dir.joinpath("bins"), + comp_options.path, latent, comp_metadata.identifiers, # type:ignore comp_metadata.lengths, # type:ignore - vamb_options, logfile, - "vae_", ) - log("VAE latent clustered", logfile, 1) - del latent - fin_cluster_latent = time.time() - - if vamb_options.min_fasta_output_size is not None: - path = comp_options.path - assert isinstance(path, FASTAPath) - write_fasta( - vamb_options.out_dir, - clusterspath, - path, - comp_metadata.identifiers, # type:ignore - comp_metadata.lengths, - vamb_options.min_fasta_output_size, - logfile, - separator=cluster_options.binsplit_separator, - ) - - writing_bins_time = round(time.time() - fin_cluster_latent, 2) - log(f"VAE bins written in {writing_bins_time} seconds.", logfile, 1) if aae_options is not None: assert latent_z is not None assert clusters_y_dict is not None assert comp_metadata.nseqs == len(latent_z) - # Cluster, save tsv file - clusterspath = vamb_options.out_dir.joinpath("aae_z_clusters.tsv") - - cluster( + cluster_and_write_files( + vamb_options, cluster_options, - clusterspath, + str(vamb_options.out_dir.joinpath("aae_z_clusters")), + vamb_options.out_dir.joinpath("bins"), + comp_options.path, latent_z, comp_metadata.identifiers, # type:ignore comp_metadata.lengths, # type:ignore - vamb_options, logfile, - "aae_z_", ) - - fin_cluster_latent_z = time.time() - log("AAE z latent clustered.", logfile, 1) - del latent_z - - if vamb_options.min_fasta_output_size is not None: - path = comp_options.path - assert isinstance(path, FASTAPath) - write_fasta( - vamb_options.out_dir, - clusterspath, - path, - comp_metadata.identifiers, # type:ignore - comp_metadata.lengths, - vamb_options.min_fasta_output_size, - logfile, - separator=cluster_options.binsplit_separator, - ) - time_writing_bins_z = time.time() - writing_bins_time_z = round(time_writing_bins_z - fin_cluster_latent_z, 2) - log(f"AAE z bins written in {writing_bins_time_z} seconds.", logfile, 1) - - clusterspath = vamb_options.out_dir.joinpath("aae_y_clusters.tsv") - # Binsplit if given a separator - if cluster_options.binsplit_separator is not None: - maybe_split = vamb.vambtools.binsplit( - clusters_y_dict.items(), cluster_options.binsplit_separator - ) - else: - maybe_split = 
clusters_y_dict.items() - with open(clusterspath, "w") as clustersfile: - clusternumber, ncontigs = vamb.vambtools.write_clusters( - clustersfile, - maybe_split, - max_clusters=cluster_options.max_clusters, - min_size=cluster_options.min_cluster_size, - rename=False, - cluster_prefix="aae_y_", - ) - - print("", file=logfile) - log(f"Clustered {ncontigs} contigs in {clusternumber} bins", logfile, 1) - time_start_writin_z_bins = time.time() - if vamb_options.min_fasta_output_size is not None: - path = comp_options.path - assert isinstance(path, FASTAPath) - write_fasta( - vamb_options.out_dir, - clusterspath, - path, - comp_metadata.identifiers, # type:ignore - comp_metadata.lengths, - vamb_options.min_fasta_output_size, - logfile, - separator=cluster_options.binsplit_separator, - ) - time_writing_bins_y = time.time() - writing_bins_time_y = round(time_writing_bins_y - time_start_writin_z_bins, 2) - log(f"AAE y bins written in {writing_bins_time_y} seconds.", logfile, 1) - - log(f"\nCompleted Vamb in {round(time.time() - begintime, 2)} seconds.", logfile, 0) + # We enforce this in the VAEAAEOptions constructor, see comment there + assert cluster_options.max_clusters is None + write_clusters_and_bins( + vamb_options, + cluster_options.binsplitter, + str(vamb_options.out_dir.joinpath("aae_y_clusters")), + vamb_options.out_dir.joinpath("bins"), + comp_options.path, + clusters_y_dict, + comp_metadata.identifiers, # type:ignore + comp_metadata.lengths, # type:ignore + logfile, + ) def parse_mmseqs_taxonomy( @@ -1224,12 +1156,14 @@ def extract_and_filter_data( vamb_options: VambOptions, comp_options: CompositionOptions, abundance_options: AbundanceOptions, + binsplitter: vamb.vambtools.BinSplitter, logfile: IO[str], ) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: composition, abundance = load_composition_and_abundance( vamb_options=vamb_options, comp_options=comp_options, abundance_options=abundance_options, + binsplitter=binsplitter, logfile=logfile, ) @@ -1255,6 +1189,7 @@ def run_taxonomy_predictor( vamb_options=vamb_options, comp_options=comp_options, abundance_options=abundance_options, + binsplitter=vamb.vambtools.BinSplitter.inert_splitter(), logfile=logfile, ) assert taxonomy_options.taxonomy_path is not None @@ -1287,6 +1222,7 @@ def run_vaevae( vamb_options=vamb_options, comp_options=comp_options, abundance_options=abundance_options, + binsplitter=cluster_options.binsplitter, logfile=logfile, ) @@ -1411,38 +1347,17 @@ def run_vaevae( np.save(LATENT_PATH, latent_both) # Cluster, save tsv file - clusterspath = vamb_options.out_dir.joinpath("vaevae_clusters.tsv") - cluster( + cluster_and_write_files( + vamb_options, cluster_options, - clusterspath, + str(vamb_options.out_dir.joinpath("vaevae_clusters")), + vamb_options.out_dir.joinpath("bins"), + comp_options.path, latent_both, - contignames, - lengths, - vamb_options, + contignames, # type:ignore + lengths, # type:ignore logfile, - "vaevae_", ) - log("VAEVAE latent clustered", logfile, 1) - - del latent_both - fin_cluster_latent = time.time() - - if vamb_options.min_fasta_output_size is not None: - path = comp_options.path - assert isinstance(path, FASTAPath) - write_fasta( - vamb_options.out_dir, - clusterspath, - path, - contignames, - lengths, - vamb_options.min_fasta_output_size, - logfile, - separator=cluster_options.binsplit_separator, - ) - - writing_bins_time = round(time.time() - fin_cluster_latent, 2) - log(f"VAEVAE bins written in {writing_bins_time} seconds.", logfile, 1) def run_reclustering( @@ -1457,6 +1372,7 @@ 
def run_reclustering( vamb_options=vamb_options, comp_options=comp_options, abundance_options=abundance_options, + binsplitter=reclustering_options.binsplitter, logfile=logfile, ) @@ -1484,45 +1400,21 @@ def run_reclustering( predictions_path=predictions_path, ) - cluster_dict = defaultdict(set) + cluster_dict: defaultdict[str, set[str]] = defaultdict(set) for k, v in zip(reclustered, composition.metadata.identifiers): cluster_dict[k].add(v) - clusterspath = vamb_options.out_dir.joinpath("clusters_reclustered.tsv") - if reclustering_options.binsplit_separator is not None: - maybe_split = vamb.vambtools.binsplit( - cluster_dict.items(), reclustering_options.binsplit_separator - ) - else: - maybe_split = cluster_dict.items() - with open(clusterspath, "w") as clustersfile: - clusternumber, ncontigs = vamb.vambtools.write_clusters( - clustersfile, - maybe_split, - rename=False, - cluster_prefix="recluster", - ) - - print("", file=logfile) - log(f"Clustered {ncontigs} contigs in {clusternumber} bins", logfile, 1) - - fin_cluster_latent = time.time() - - if vamb_options.min_fasta_output_size is not None: - path = comp_options.path - assert isinstance(path, FASTAPath) - write_fasta( - vamb_options.out_dir, - clusterspath, - path, - composition.metadata.identifiers, - composition.metadata.lengths, - vamb_options.min_fasta_output_size, - logfile, - separator=reclustering_options.binsplit_separator, - ) - writing_bins_time = round(time.time() - fin_cluster_latent, 2) - log(f"Reclustered bins written in {writing_bins_time} seconds.", logfile, 1) + write_clusters_and_bins( + vamb_options, + reclustering_options.binsplitter, + str(vamb_options.out_dir.joinpath("clusters_reclustered")), + vamb_options.out_dir.joinpath("bins"), + comp_options.path, + cluster_dict, + composition.metadata.identifiers, # type:ignore + composition.metadata.lengths, # type:ignore + logfile, + ) class BasicArguments(object): @@ -1531,7 +1423,10 @@ class BasicArguments(object): def __init__(self, args): self.args = args self.comp_options = CompositionOptions( - self.args.fasta, self.args.composition, self.args.minlength + self.args.fasta, + self.args.composition, + self.args.minlength, + args.warn_on_few_seqs, ) self.abundance_options = AbundanceOptions( self.args.bampaths, @@ -1555,8 +1450,21 @@ def run(self): torch.set_num_threads(self.vamb_options.n_threads) try_make_dir(self.vamb_options.out_dir) with open(self.vamb_options.out_dir.joinpath(self.LOGS_PATH), "w") as logfile: + begintime = time.time() + log( + "Starting Vamb version " + ".".join(map(str, vamb.__version__)), logfile + ) + log("Date and time is " + str(datetime.datetime.now()), logfile, 1) + log("Random seed is " + str(self.vamb_options.seed), logfile, 1) + self.run_inner(logfile) + log( + f"\nCompleted Vamb in {round(time.time() - begintime, 2)} seconds.", + logfile, + 0, + ) + class TaxometerArguments(BasicArguments): def __init__(self, args): @@ -1604,7 +1512,6 @@ def __init__(self, args): self.cluster_options = ClusterOptions( args.window_size, args.min_successes, - args.min_cluster_size, args.max_clusters, args.binsplit_separator, ) @@ -2015,14 +1922,6 @@ def add_clustering_arguments(subparser): default=15, help="minimum success in window [15]", ) - clusto.add_argument( - "-i", - dest="min_cluster_size", - metavar="", - type=int, - default=1, - help="minimum cluster size [1]", - ) clusto.add_argument( "-c", dest="max_clusters", @@ -2036,8 +1935,12 @@ def add_clustering_arguments(subparser): dest="binsplit_separator", metavar="", type=str, + # This means: 
None if not passed, "" if passed but empty, e.g. "-o -c 5",
+        # otherwise a string.
         default=None,
-        help="binsplit separator [None = no split]",
+        const="",
+        nargs="?",
+        help="binsplit separator [C] (pass empty string to disable)",
     )
     return subparser
@@ -2327,8 +2230,10 @@ def main():
     args = parser.parse_args()

     if args.subcommand == TAXOMETER:
+        args.warn_on_few_seqs = True
         runner = TaxometerArguments(args)
     elif args.subcommand == BIN:
+        args.warn_on_few_seqs = True
         if args.model_subcommand is None:
             vaevae_parserbin_parser.print_help()
             sys.exit(1)
@@ -2339,6 +2244,8 @@
         }
         runner = classes_map[args.model_subcommand](args)
     elif args.subcommand == RECLUSTER:
+        # Uniquely, reclustering cannot overfit, so we don't need this warning
+        args.warn_on_few_seqs = False
         runner = ReclusteringArguments(args)
     else:
         # There are no more subcommands

diff --git a/vamb/parsecontigs.py b/vamb/parsecontigs.py
index fe08f062..5d3deb15 100644
--- a/vamb/parsecontigs.py
+++ b/vamb/parsecontigs.py
@@ -215,26 +215,4 @@ def from_file(
             _np.array(mask, dtype=bool),
             minlength,
         )
-
-        if len(metadata.lengths) < 20_000:
-            message = (
-                f"WARNING: Parsed only {len(metadata.lengths)} sequences from FASTA file. "
-                "We normally expect 20,000 sequences or more to prevent overfitting. "
-                "As a deep learning model, VAEs are prone to overfitting with too few sequences. "
-                "You may want to bin more samples as a time, lower the beta parameter, "
-                "or use a different binner altogether."
-            )
-            _vambtools.log_and_warn(message, logfile=logfile)
-
-        # Warn the user if any contigs have been observed, which is smaller
-        # than the threshold.
-        if not _np.all(metadata.mask):
-            message = (
-                f"WARNING: The minimum sequence length has been set to {minlength}, but a contig with "
-                f"length {minimum_seen_length} was seen. "
-                "Better results are obtained if the sequence file is filtered to the minimum "
-                "sequence length before mapping."
-            )
-            _vambtools.log_and_warn(message, logfile=logfile)
-
         return cls(metadata, tnfs_arr)

diff --git a/vamb/vambtools.py b/vamb/vambtools.py
index 822409b4..95a69f48 100644
--- a/vamb/vambtools.py
+++ b/vamb/vambtools.py
@@ -1,6 +1,5 @@
 __doc__ = "Various classes and functions Vamb uses internally."

-import os as _os
 import gzip as _gzip
 import bz2 as _bz2
 import lzma as _lzma
@@ -10,12 +9,146 @@
 import collections as _collections
 from itertools import zip_longest
 from hashlib import md5 as _md5
-from collections.abc import Iterable, Iterator, Generator
+from collections.abc import Iterable, Iterator
 from typing import Optional, IO, Union
 from pathlib import Path
 import warnings


+class BinSplitter:
+    """
+    The binsplitter can be either
+    * Instantiated with an explicit option, in which case `is_default` is False,
+      and `splitter` is None if the user explicitly opted out of binsplitting by
+      passing in an empty string, otherwise the string the user explicitly asked for
+    * Instantiated by default, in which case `is_default` is `True`, and `splitter`
+      is `_DEFAULT_SPLITTER`
+
+    The `initialize` function checks the validity of the binsplitter on the set of
+    identifiers:
+    * If the separator is `None` (binsplitting disabled), do nothing
+    * If the binsplitter is default and the separator is not found, warn the user,
+      then set the separator to `None` (disabling binsplitting)
+    * If the binsplitter is explicitly a string, check that this string occurs in
+      all identifiers, else error early.
+ """ + + _DEFAULT_SPLITTER = "C" + __slots__ = ["is_default", "splitter"] + + def __init__(self, binsplitter: Optional[str]): + if binsplitter is None: + self.is_default = True + self.splitter = self._DEFAULT_SPLITTER + + else: + self.is_default = False + if len(binsplitter) == 0: + self.splitter = None + else: + self.splitter = binsplitter + + @classmethod + def inert_splitter(cls): + return cls("") + + def initialize(self, identifiers: Iterable[str]): + separator = self.splitter + if separator is None: + return None + message = ( + 'Binsplit separator (option `-o`) {imexplicit} passed as "{separator}", ' + 'but sequence identifier "{identifier}" does not contain this separator, ' + "or contains it at the very start or end.\n" + "A binsplit separator X implies that every sequence identifier is formatted as\n" + "[sample identifier][X][sequence identifier], e.g. a binsplit separator of 'C' " + "means that 'S1C19' and '7C11' are valid identifiers." + ) + + if not self.is_default: + for identifier in identifiers: + (front, _, rest) = identifier.partition(separator) + if not front or not rest: + log_and_raise( + message.format( + imexplicit="explicitly", + separator=separator, + identifier=identifier, + ) + ) + else: + for identifier in identifiers: + (front, _, rest) = identifier.partition(separator) + if not front or not rest: + message += "\nSkipping binsplitting." + log_and_warn( + message.format( + imexplicit="implicitly", + separator=separator, + identifier=identifier, + ) + ) + self.splitter = None + break + + def split_bin( + self, + binname: str, + identifiers: Iterable[str], + ) -> Iterable[tuple[str, set[str]]]: + "Split a single bin by the prefix of the headers" + if self.splitter is None: + yield (binname, set(identifiers)) + return None + else: + by_sample: dict[str, set[str]] = _collections.defaultdict(set) + for identifier in identifiers: + sample, _, rest = identifier.partition(self.splitter) + + if not rest or not sample: + raise KeyError( + f"Separator '{self.splitter}' not in sequence identifier, or is at the very start or end of identifier: '{identifier}'" + ) + + by_sample[sample].add(identifier) + + for sample, splitheaders in by_sample.items(): + newbinname = f"{sample}{self.splitter}{binname}" + yield newbinname, splitheaders + + def binsplit( + self, + clusters: Iterable[tuple[str, Iterable[str]]], + ) -> Iterable[tuple[str, set[str]]]: + """Splits a set of clusters by the prefix of their names. + The separator is a string which separated prefix from postfix of contignames. The + resulting split clusters have the prefix and separator prepended to them. + + clusters can be an iterator, in which case this function returns an iterator, or a dict + with contignames: set_of_contignames pair, in which case a dict is returned. 
+ + Example: + >>> clusters = {"bin1": {"s1-c1", "s1-c5", "s2-c1", "s2-c3", "s5-c8"}} + >>> binsplit(clusters, "-") + {'s2-bin1': {'s1-c1', 's1-c3'}, 's1-bin1': {'s1-c1', 's1-c5'}, 's5-bin1': {'s1-c8'}} + """ + for binname, headers in clusters: + for newbinname, splitheaders in self.split_bin(binname, headers): + yield newbinname, splitheaders + + def log_string(self) -> str: + if not self.is_default: + if self.splitter is None: + return "Explicitly passed as empty (no binsplitting)" + else: + return '"{self.splitter}"' + else: + if self.splitter is None: + return "Defaulting to 'C', but disabled due to incompatible identifiers" + else: + return "Defaulting to 'C'" + + def showwarning_override(message, category, filename, lineno, file=None, line=None): print(str(message) + "\n", file=file) @@ -440,67 +573,18 @@ def verify_refhash( def write_clusters( - filehandle: IO[str], + unsplit_io: IO[str], clusters: Iterable[tuple[str, set[str]]], - max_clusters: Optional[int] = None, - min_size: int = 1, - header: Optional[str] = None, - rename: bool = True, - cluster_prefix: str = "", ) -> tuple[int, int]: - """Writes clusters to an open filehandle. - Inputs: - filehandle: An open filehandle that can be written to - clusters: An iterator generated by function `clusters` or a dict - max_clusters: Stop printing after this many clusters [None] - min_size: Don't output clusters smaller than N contigs - header: Commented one-line header to add - rename: Rename clusters to "cluster_1", "cluster_2" etc. - cluster_prefix: prepend a tag to identify which model produced the clusters (vae,aae_l, aae_y) - - Outputs: - clusternumber: Number of clusters written - ncontigs: Number of contigs written - """ - - if not hasattr(filehandle, "writable") or not filehandle.writable(): - raise ValueError("Filehandle must be a writable file") - - if max_clusters is not None and max_clusters < 1: - raise ValueError("max_clusters must None or at least 1, not {max_clusters}") - - if header is not None and len(header) > 0: - if "\n" in header: - raise ValueError("Header cannot contain newline") - - if header[0] != "#": - header = "# " + header + n_clusters = 0 + n_contigs = 0 + for cluster_name, contig_names in clusters: + n_clusters += 1 + n_contigs += len(contig_names) + for contig_name in contig_names: + print(cluster_name, contig_name, sep="\t", file=unsplit_io) - print(header, file=filehandle) - - clusternumber = 0 - ncontigs = 0 - - for clustername, contigs in clusters: - if len(contigs) < min_size: - continue - - if rename: - clustername = cluster_prefix + "cluster_" + str(clusternumber + 1) - else: - clustername = cluster_prefix + str(clusternumber + 1) - - for contig in contigs: - print(clustername, contig, sep="\t", file=filehandle) - filehandle.flush() - - clusternumber += 1 - ncontigs += len(contigs) - - if clusternumber == max_clusters: - break - - return clusternumber, ncontigs + return (n_clusters, n_contigs) def read_clusters(filehandle: Iterable[str], min_size: int = 1) -> dict[str, set[str]]: @@ -528,24 +612,36 @@ def read_clusters(filehandle: Iterable[str], min_size: int = 1) -> dict[str, set return contigsof_dict +def check_is_creatable_file_path(path: Path) -> None: + if path.exists(): + raise FileExistsError(path) + if not path.parent.is_dir(): + raise NotADirectoryError(path.parent) + + +def create_dir_if_not_existing(path: Path) -> None: + if path.is_dir(): + return None + if path.is_file(): + raise FileExistsError(path) + if not path.parent.is_dir(): + raise NotADirectoryError(path.parent) + 
path.mkdir(exist_ok=True) + + def write_bins( - directory: Union[str, Path], + directory: Path, bins: dict[str, set[str]], fastaio: Iterable[bytes], maxbins: Optional[int] = 250, - minsize: int = 0, - separator: Optional[str] = None, ): """Writes bins as FASTA files in a directory, one file per bin. Inputs: directory: Directory to create or put files in - bins: dict[str: set[str]] (can be loaded from - clusters.tsv using vamb.cluster.read_clusters) + bins: dict[str: set[str]] (can be loaded from clusters.tsv using vamb.cluster.read_clusters) fastaio: bytes iterator containing FASTA file with all sequences maxbins: None or else raise an error if trying to make more bins than this [250] - minsize: Minimum number of nucleotides in cluster to be output [0] - separator: string that separates the contig/cluster name from the sample ; i.e. sample_id_separator_contig_name/cluster_name Output: None """ @@ -555,67 +651,32 @@ def write_bins( if maxbins is not None and len(bins) > maxbins: raise ValueError(f"{len(bins)} bins exceed maxbins of {maxbins}") - # Check that the directory is not a non-directory file, - # and that its parent directory indeed exists - abspath = _os.path.abspath(directory) - parentdir = _os.path.dirname(abspath) - - if parentdir != "" and not _os.path.isdir(parentdir): - raise NotADirectoryError(parentdir) - - if _os.path.isfile(abspath): - raise FileExistsError(abspath) + create_dir_if_not_existing(directory) - if minsize < 0: - raise ValueError("Minsize must be nonnegative") + keep: set[str] = set() + for i in bins.values(): + keep.update(i) - byteslen_by_id: dict[str, tuple[bytes, int]] = dict() + bytes_by_id: dict[str, bytes] = dict() for entry in byte_iterfasta(fastaio): - byteslen_by_id[entry.identifier] = ( - _gzip.compress(entry.format().encode(), compresslevel=1), - len(entry), - ) - - # Make the directory if it does not exist - if it does, do nothing - try: - _os.mkdir(directory) - except FileExistsError: - pass + if entry.identifier in keep: + bytes_by_id[entry.identifier] = _gzip.compress( + entry.format().encode(), compresslevel=1 + ) # Now actually print all the contigs to files for binname, contigs in bins.items(): - size = 0 - if separator is not None: - binsample = next(iter(contigs)).split(separator)[0] - else: - binsample = None for contig in contigs: - byteslen = byteslen_by_id.get(contig) - if byteslen is None: + byts = bytes_by_id.get(contig) + if byts is None: raise IndexError( f'Contig "{contig}" in bin missing from input FASTA file' ) - size += byteslen[1] - - if size < minsize: - continue - - # Split bin files into sample dirs - if binsample is not None: - bin_dir = _os.path.join(directory, binsample) - try: - _os.mkdir(bin_dir) - except FileExistsError: - pass - else: - bin_dir = directory # Print bin to file - filename = _os.path.join(bin_dir, binname + ".fna") - - with open(filename, "wb") as file: + with open(directory.joinpath(binname + ".fna"), "wb") as file: for contig in contigs: - file.write(_gzip.decompress(byteslen_by_id[contig][0])) + file.write(_gzip.decompress(bytes_by_id[contig])) file.write(b"\n") @@ -712,50 +773,3 @@ def open_file_iterator(paths: Iterable[Path]) -> Iterable[Reader]: for path in paths: with Reader(path) as io: yield io - - -def _split_bin( - binname: str, - headers: Iterable[str], - separator: str, - bysample: _collections.defaultdict[str, set[str]] = _collections.defaultdict(set), -) -> Generator[tuple[str, set[str]], None, None]: - "Split a single bin by the prefix of the headers" - - bysample.clear() - for header in 
headers: - if not isinstance(header, str): # type: ignore - raise TypeError( - f"Can only split named sequences, not of type {type(header)}" - ) - - sample, _, identifier = header.partition(separator) - - if not identifier: - raise KeyError(f"Separator '{separator}' not in sequence label: '{header}'") - - bysample[sample].add(header) - - for sample, splitheaders in bysample.items(): - newbinname = f"{sample}{separator}{binname}" - yield newbinname, splitheaders - - -def binsplit( - clusters: Iterable[tuple[str, Iterable[str]]], separator: str -) -> Generator[tuple[str, set[str]], None, None]: - """Splits a set of clusters by the prefix of their names. - The separator is a string which separated prefix from postfix of contignames. The - resulting split clusters have the prefix and separator prepended to them. - - clusters can be an iterator, in which case this function returns an iterator, or a dict - with contignames: set_of_contignames pair, in which case a dict is returned. - - Example: - >>> clusters = {"bin1": {"s1-c1", "s1-c5", "s2-c1", "s2-c3", "s5-c8"}} - >>> binsplit(clusters, "-") - {'s2-bin1': {'s1-c1', 's1-c3'}, 's1-bin1': {'s1-c1', 's1-c5'}, 's5-bin1': {'s1-c8'}} - """ - for binname, headers in clusters: - for newbinname, splitheaders in _split_bin(binname, headers, separator): - yield newbinname, splitheaders diff --git a/workflow_avamb/avamb.snake.conda.smk b/workflow_avamb/avamb.snake.conda.smk index 99d0b38c..c80cb046 100644 --- a/workflow_avamb/avamb.snake.conda.smk +++ b/workflow_avamb/avamb.snake.conda.smk @@ -332,9 +332,9 @@ rule run_avamb: abundance=os.path.join(OUTDIR,"abundance.npz") output: outdir_avamb=directory(os.path.join(OUTDIR,"avamb")), - clusters_aae_z=os.path.join(OUTDIR,"avamb/aae_z_clusters.tsv"), - clusters_aae_y=os.path.join(OUTDIR,"avamb/aae_y_clusters.tsv"), - clusters_vamb=os.path.join(OUTDIR,"avamb/vae_clusters.tsv"), + clusters_aae_z=os.path.join(OUTDIR,"avamb/aae_z_clusters_split.tsv"), + clusters_aae_y=os.path.join(OUTDIR,"avamb/aae_y_clusters_split.tsv"), + clusters_vamb=os.path.join(OUTDIR,"avamb/vae_clusters_split.tsv"), contignames=os.path.join(OUTDIR,"avamb/contignames"), contiglenghts=os.path.join(OUTDIR,"avamb/lengths.npz") params: @@ -473,9 +473,9 @@ rule run_drep_manual_vamb_z_y: cluster_score_dict_path_avamb=os.path.join(OUTDIR,"tmp/cs_d_avamb.json"), contignames=os.path.join(OUTDIR,"avamb/contignames"), contiglengths=os.path.join(OUTDIR,"avamb/lengths.npz"), - clusters_aae_z=os.path.join(OUTDIR,"avamb/aae_z_clusters.tsv"), - clusters_aae_y=os.path.join(OUTDIR,"avamb/aae_y_clusters.tsv"), - clusters_vamb=os.path.join(OUTDIR,"avamb/vae_clusters.tsv") + clusters_aae_z=os.path.join(OUTDIR,"avamb/aae_z_clusters_split.tsv"), + clusters_aae_y=os.path.join(OUTDIR,"avamb/aae_y_clusters_split.tsv"), + clusters_vamb=os.path.join(OUTDIR,"avamb/vae_clusters_split.tsv") output: clusters_avamb_manual_drep=os.path.join(OUTDIR,"tmp/avamb_manual_drep_clusters.tsv")
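
A minimal usage sketch of the `BinSplitter` API added in `vamb/vambtools.py` above,
for readers trying out this branch. The identifiers and cluster names below are
invented for illustration; only the class and its methods come from this patch:

    from vamb.vambtools import BinSplitter

    # Default construction, equivalent to `-o C`: is_default=True, splitter="C"
    splitter = BinSplitter(None)

    # Validate the separator against every identifier up front.
    # A default splitter warns and disables itself on a mismatch;
    # an explicit one, e.g. BinSplitter("C"), raises an error instead.
    splitter.initialize(["S1C101", "S1C202", "S2C101"])

    # Unsplit clusters, keyed by cluster name
    clusters = {"1": {"S1C101", "S2C101"}, "2": {"S1C202"}}

    # binsplit lazily yields one (new_bin_name, contigs) pair per sample per bin
    for name, contigs in splitter.binsplit(clusters.items()):
        print(name, sorted(contigs))
    # Prints (bin order may vary):
    # S1C1 ['S1C101']
    # S2C1 ['S2C101']
    # S1C2 ['S1C202']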