Skip to content

Commit

Permalink
Refactor binsplitting
Browse files Browse the repository at this point in the history
This is a larger change which overhauls how binsplitting is done, and, as a
consequence, reworks some of the overall workflow in `__main__.py`.
The PR is intended to address the following problems:
* Before, we only output either the binsplit clusters, or the unsplit clusters.
  This is problematic, because we know the binsplit clusters are the best ones,
  so we would like to output these. However, the unsplit ones contain important
  information about the source cluster, which power users need to be able to
  recover.
  	- Now, we output both `_split.tsv` and `_unsplit.tsv` files, if
  	  binsplitting takes place.
* Before, we defaulted to no binsplitting, even though we knew it was inferior
    - Now, `-o C` is default.
* Before, if a user passed in a wrong binsplit separator, Vamb would not error
  until the clustering step, and the error message would be inscrutable
    - Now, error already when parsing the contigs, EXCEPT if the binsplit sep
      has defaulted to 'C', in which case binsplitting is disabled, and the user
      is warned
    - The error message is significantly improved and more explanatory
* Before, the logic of where binsplitting happened was ad-hoc, and scattered
  all over the place. For example, binsplitting took place during cluster
  writing, during bin writing, during benchmarking, during clustering itself,
  and immediately after clustering. It was also implemented multiple places.
    - Now, create a `BinSplitter` class responsible for binsplitting.
      The writer functions and loader functions do not binsplit.
    - Now, binsplitting mostly takes place immediately before writing the
      split clusters, meaning the clusters are unambiguously unsplit for the
      majority of the program
  • Loading branch information
jakobnissen committed Nov 14, 2023
1 parent ff5eac6 commit 15638ff
Show file tree
Hide file tree
Showing 8 changed files with 456 additions and 567 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/cli_vamb.yml
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,6 @@ jobs:
cat outdir_taxometer/log.txt
- name: Run k-means reclustering
run: |
vamb recluster --outdir outdir_recluster --fasta catalogue_mock.fna.gz --rpkm abundance_mock.npz --latent_path outdir_taxvamb/vaevae_latent.npy --clusters_path outdir_taxvamb/vaevae_clusters.tsv --hmmout_path markers_mock.hmmout --algorithm kmeans --minfasta 200000
vamb recluster --outdir outdir_recluster --fasta catalogue_mock.fna.gz --rpkm abundance_mock.npz --latent_path outdir_taxvamb/vaevae_latent.npy --clusters_path outdir_taxvamb/vaevae_clusters_split.tsv --hmmout_path markers_mock.hmmout --algorithm kmeans --minfasta 200000
ls -la outdir_recluster
cat outdir_recluster/log.txt
32 changes: 5 additions & 27 deletions test/test_parsecontigs.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
import unittest
import random
import numpy as np
import warnings

import testtools
from vamb.parsecontigs import Composition, CompositionMetaData
Expand Down Expand Up @@ -41,26 +40,9 @@ def test_unique_names(self):
1000,
)

# Does not warn
def test_nowarn(self):
with warnings.catch_warnings():
warnings.simplefilter("error", UserWarning)
Composition.from_file(self.large_io, minlength=250)

def test_warns_n_contigs(self):
with self.assertWarns(UserWarning):
Composition.from_file(self.io, minlength=250)

def test_warns_minlength(self):
with self.assertWarns(UserWarning):
Composition.from_file(self.large_io, minlength=275)

def test_filter_minlength(self):
minlen = 500

with self.assertWarns(UserWarning):
composition = Composition.from_file(self.io, minlength=450)

composition = Composition.from_file(self.io, minlength=450)
md = composition.metadata
hash1 = md.refhash

Expand Down Expand Up @@ -99,8 +81,7 @@ def test_minlength(self):
Composition.from_file(self.io, minlength=3)

def test_properties(self):
with self.assertWarns(UserWarning):
composition = Composition.from_file(self.io, minlength=420)
composition = Composition.from_file(self.io, minlength=420)
passed = list(filter(lambda x: len(x.sequence) >= 420, self.records))

self.assertEqual(composition.nseqs, len(composition.metadata.identifiers))
Expand All @@ -122,8 +103,7 @@ def test_properties(self):

def test_save_load(self):
buf = io.BytesIO()
with self.assertWarns(UserWarning):
composition_1 = Composition.from_file(self.io)
composition_1 = Composition.from_file(self.io)
md1 = composition_1.metadata
composition_1.save(buf)
buf.seek(0)
Expand Down Expand Up @@ -153,10 +133,8 @@ def test_windows_newlines(self):

buf1.seek(0)
buf2.seek(0)
with self.assertWarns(UserWarning):
comp1 = Composition.from_file(buf1)
with self.assertWarns(UserWarning):
comp2 = Composition.from_file(buf2)
comp1 = Composition.from_file(buf1)
comp2 = Composition.from_file(buf2)

self.assertEqual(comp1.metadata.refhash, comp2.metadata.refhash)
self.assertTrue(np.all(comp1.matrix == comp2.matrix))
5 changes: 2 additions & 3 deletions test/test_results.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,8 @@ def setUp(self):
self.io.seek(0)

def test_runs(self):
with self.assertRaises(UserWarning):
comp = vamb.parsecontigs.Composition.from_file(self.io)
self.assertIsInstance(comp, vamb.parsecontigs.Composition)
comp = vamb.parsecontigs.Composition.from_file(self.io)
self.assertIsInstance(comp, vamb.parsecontigs.Composition)

if TEST_UNSTABLE_HASHES:

Expand Down
127 changes: 70 additions & 57 deletions test/test_vambtools.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,11 @@
import numpy as np
import string
import torch
import pathlib
import shutil

import vamb
from vamb.vambtools import BinSplitter
import testtools

PARENTDIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
Expand Down Expand Up @@ -220,7 +223,7 @@ def test_bad_files(self):
# String input
with self.assertRaises(TypeError):
data = ">abc\nTAG\na\nAC".splitlines()
list(vamb.vambtools.byte_iterfasta(data))
list(vamb.vambtools.byte_iterfasta(data)) # type:ignore

# Various correct formats
def test_good_files(self):
Expand Down Expand Up @@ -379,16 +382,48 @@ class TestBinSplit(unittest.TestCase):
("s12-bin2", {"s12-c0"}),
]

def test_inert(self):
self.assertEqual(
BinSplitter("").splitter, BinSplitter.inert_splitter().splitter
)

def test_split(self):
self.assertEqual(list(vamb.vambtools.binsplit(self.before, "-")), self.after)
self.assertEqual(list(BinSplitter("-").binsplit(self.before)), self.after)

def test_badsep(self):
with self.assertRaises(KeyError):
list(vamb.vambtools.binsplit(self.before, "2"))
list(BinSplitter("2").binsplit(self.before))

def test_badtype(self):
with self.assertRaises(TypeError):
list(vamb.vambtools.binsplit([(1, [2])], ""))
with self.assertRaises(Exception):
list(BinSplitter("x").binsplit([(1, [2])])) # type:ignore

def test_nosplit(self):
self.assertEqual(
list(BinSplitter("").binsplit(self.before)),
[(k, set(s)) for (k, s) in self.before],
)

def test_initialize(self):
# Nothing happens to an inert splitter
b = BinSplitter.inert_splitter()
s = b.splitter
b.initialize([""])
self.assertEqual(s, b.splitter)

b = BinSplitter("X")
with self.assertRaises(ValueError):
b.initialize(["AXC", "S1C2"])

b = BinSplitter(None)
with self.assertWarns(UserWarning):
b.initialize(["S1C2", "ABC"])

b = BinSplitter(None)
b.initialize(["S1C2", "KMCPLK"])

b = BinSplitter("XYZ")
b.initialize(["ABXYZCD", "KLMXYZA"])


class TestConcatenateFasta(unittest.TestCase):
Expand Down Expand Up @@ -459,8 +494,8 @@ def setUp(self):
self.io.truncate(0)
self.io.seek(0)

def linesof(self, str):
return list(filter(lambda x: not x.startswith("#"), str.splitlines()))
def linesof(self, string: str):
return list(filter(lambda x: not x.startswith("#"), string.splitlines()))

def conforms(self, str, clusters):
lines = self.linesof(str)
Expand Down Expand Up @@ -488,46 +523,29 @@ def conforms(self, str, clusters):
self.assertEqual(read_names, printed_names)

def test_not_writable(self):
buf = io.BufferedReader(io.BytesIO(b""))
buf = io.BufferedReader(io.BytesIO(b"")) # type:ignore
with self.assertRaises(ValueError):
vamb.vambtools.write_clusters(buf, self.test_clusters)

def test_invalid_max_clusters(self):
with self.assertRaises(ValueError):
vamb.vambtools.write_clusters(self.io, self.test_clusters, max_clusters=0)

with self.assertRaises(ValueError):
vamb.vambtools.write_clusters(self.io, self.test_clusters, max_clusters=-11)

def test_header_has_newline(self):
with self.assertRaises(ValueError):
vamb.vambtools.write_clusters(self.io, self.test_clusters, header="foo\n")
vamb.vambtools.write_clusters(buf, self.test_clusters) # type:ignore

def test_normal(self):
vamb.vambtools.write_clusters(self.io, self.test_clusters, header="someheader")
self.assertTrue(self.io.getvalue().startswith("# someheader"))
vamb.vambtools.write_clusters(self.io, self.test_clusters)
self.conforms(self.io.getvalue(), self.test_clusters)

def test_max_clusters(self):
vamb.vambtools.write_clusters(self.io, self.test_clusters, max_clusters=2)
vamb.vambtools.write_clusters(self.io, self.test_clusters[:2])
lines = self.linesof(self.io.getvalue())
self.assertEqual(len(lines), 9)
self.conforms(self.io.getvalue(), self.test_clusters[:2])

def test_min_size(self):
vamb.vambtools.write_clusters(self.io, self.test_clusters, min_size=5)
lines = self.linesof(self.io.getvalue())
self.assertEqual(len(lines), 6)
self.conforms(self.io.getvalue(), self.test_clusters[1:2])


class TestWriteBins(unittest.TestCase):
file = io.BytesIO()
N_BINS = 10
minsize = 5 * 175 # mean of bin size
dirname = os.path.join(
tempfile.gettempdir(),
"".join(random.choices(string.ascii_letters + string.digits, k=10)),
dir = pathlib.Path(
os.path.join(
tempfile.gettempdir(),
"".join(random.choices(string.ascii_letters + string.digits, k=10)),
)
)

@classmethod
Expand All @@ -552,71 +570,66 @@ def setUp(self):

def tearDown(self):
try:
os.rmdir(self.dirname)
shutil.rmtree(self.dir)
except FileNotFoundError:
pass

def test_bad_params(self):
# Too many bins for maxbins
with self.assertRaises(ValueError):
vamb.vambtools.write_bins(
self.dirname, self.bins, self.file, maxbins=self.N_BINS - 1
)

# Negative minsize
with self.assertRaises(ValueError):
vamb.vambtools.write_bins(
self.dirname, self.bins, self.file, maxbins=self.N_BINS + 1, minsize=-1
self.dir, self.bins, self.file, maxbins=self.N_BINS - 1
)

# Parent does not exist
with self.assertRaises(NotADirectoryError):
vamb.vambtools.write_bins(
"svogew/foo", self.bins, self.file, maxbins=self.N_BINS + 1
pathlib.Path("svogew/foo"),
self.bins,
self.file,
maxbins=self.N_BINS + 1,
)

# Target file already exists
# Target is an existing file
with self.assertRaises(FileExistsError):
with tempfile.NamedTemporaryFile() as file:
vamb.vambtools.write_bins(
file.name, self.bins, self.file, maxbins=self.N_BINS + 1
pathlib.Path(file.name),
self.bins,
self.file,
maxbins=self.N_BINS + 1,
)

# One contig missing from fasta dict
with self.assertRaises(IndexError):
bins = {k: v.copy() for k, v in self.bins.items()}
next(iter(bins.values())).add("a_new_bin_which_does_not_exist")
vamb.vambtools.write_bins(
self.dirname, bins, self.file, maxbins=self.N_BINS + 1
self.dir, bins, self.file, maxbins=self.N_BINS + 1
)

def test_round_trip(self):
with tempfile.TemporaryDirectory() as dir:
vamb.vambtools.write_bins(
dir, self.bins, self.file, maxbins=self.N_BINS, minsize=self.minsize
pathlib.Path(dir),
self.bins,
self.file,
maxbins=self.N_BINS,
)

reconstructed_bins: dict[str, set[str]] = dict()
for filename in os.listdir(dir):
with open(os.path.join(dir, filename), "rb") as file:
entries = list(vamb.vambtools.byte_iterfasta(file))
if sum(map(len, entries)) < self.minsize:
continue
binname = filename[:-4]
reconstructed_bins[binname] = set()
for entry in entries:
reconstructed_bins[binname].add(entry.identifier)

filtered_bins = {
k: v
for (k, v) in self.bins.items()
if sum(len(self.seqs[vi]) for vi in v) >= self.minsize
}

# Same bins
self.assertEqual(len(filtered_bins), len(reconstructed_bins))
self.assertEqual(len(self.bins), len(reconstructed_bins))
self.assertEqual(
sum(map(len, filtered_bins.values())),
sum(map(len, self.bins.values())),
sum(map(len, reconstructed_bins.values())),
)
self.assertEqual(filtered_bins, reconstructed_bins)
self.assertEqual(self.bins, reconstructed_bins)
Loading

0 comments on commit 15638ff

Please sign in to comment.