diff --git a/test/test_encode.py b/test/test_encode.py index 6a2850e9..2897270c 100644 --- a/test/test_encode.py +++ b/test/test_encode.py @@ -157,10 +157,6 @@ def test_samples_too_small(self): t = self.tnfs.copy() l = self.lens.copy() - with self.assertWarns(UserWarning): - dl = vamb.encode.make_dataloader(r, t, l, batchsize=256) - vae.trainmodel(dl, batchsteps=None, nepochs=2) - def test_loss_falls(self): vae = vamb.encode.VAE(self.rpkm.shape[1]) rpkm_copy = self.rpkm.copy() diff --git a/test/test_parsecontigs.py b/test/test_parsecontigs.py index 20adbe50..bd78a63a 100644 --- a/test/test_parsecontigs.py +++ b/test/test_parsecontigs.py @@ -2,6 +2,7 @@ import unittest import random import numpy as np +import warnings import testtools from vamb.parsecontigs import Composition, CompositionMetaData @@ -9,6 +10,7 @@ class TestReadContigs(unittest.TestCase): records = [] + large_io = io.BytesIO() io = io.BytesIO() @classmethod @@ -21,8 +23,14 @@ def setUpClass(cls): cls.io.write(i.format().encode()) cls.io.write(b"\n") + for i in range(25_000): + record = testtools.make_randseq(rng, 250, 300) + cls.large_io.write(record.format().encode()) + cls.large_io.write(b"\n") + def setUp(self): self.io.seek(0) + self.large_io.seek(0) def test_unique_names(self): with self.assertRaises(ValueError): @@ -33,9 +41,22 @@ def test_unique_names(self): 1000, ) + # Does not warn + def test_nowarn(self): + with warnings.catch_warnings(): + warnings.simplefilter("error", UserWarning) + Composition.from_file(self.large_io, minlength=250) + + def test_warns_n_contigs(self): + with self.assertWarns(UserWarning): + Composition.from_file(self.io, minlength=250) + def test_filter_minlength(self): minlen = 500 - composition = Composition.from_file(self.io, minlength=450) + + with self.assertWarns(UserWarning): + composition = Composition.from_file(self.io, minlength=450) + md = composition.metadata hash1 = md.refhash @@ -74,7 +95,8 @@ def test_minlength(self): Composition.from_file(self.io, 
minlength=3) def test_properties(self): - composition = Composition.from_file(self.io, minlength=420) + with self.assertWarns(UserWarning): + composition = Composition.from_file(self.io, minlength=420) passed = list(filter(lambda x: len(x.sequence) >= 420, self.records)) self.assertEqual(composition.nseqs, len(composition.metadata.identifiers)) @@ -96,7 +118,8 @@ def test_properties(self): def test_save_load(self): buf = io.BytesIO() - composition_1 = Composition.from_file(self.io) + with self.assertWarns(UserWarning): + composition_1 = Composition.from_file(self.io) md1 = composition_1.metadata composition_1.save(buf) buf.seek(0) @@ -126,8 +149,10 @@ def test_windows_newlines(self): buf1.seek(0) buf2.seek(0) - comp1 = Composition.from_file(buf1) - comp2 = Composition.from_file(buf2) + with self.assertWarns(UserWarning): + comp1 = Composition.from_file(buf1) + with self.assertWarns(UserWarning): + comp2 = Composition.from_file(buf2) self.assertEqual(comp1.metadata.refhash, comp2.metadata.refhash) self.assertTrue(np.all(comp1.matrix == comp2.matrix)) diff --git a/vamb/__main__.py b/vamb/__main__.py index 21d7b90e..69a5a02a 100755 --- a/vamb/__main__.py +++ b/vamb/__main__.py @@ -524,7 +524,7 @@ def calc_tnf( log(f"Loading data from FASTA file {path}", logfile, 1) with vamb.vambtools.Reader(str(path)) as file: composition = vamb.parsecontigs.Composition.from_file( - file, minlength=options.min_contig_length + file, minlength=options.min_contig_length, logfile=logfile ) composition.save(outdir.joinpath("composition.npz")) diff --git a/vamb/encode.py b/vamb/encode.py index 845b90d2..15838b39 100644 --- a/vamb/encode.py +++ b/vamb/encode.py @@ -9,7 +9,6 @@ from torch import Tensor from torch import nn as _nn from math import log as _log -import warnings __doc__ = """Encode a depths matrix and a tnf matrix to latent representation. 
@@ -87,14 +86,6 @@ def make_dataloader( if not (rpkm.dtype == tnf.dtype == _np.float32): raise ValueError("TNF and RPKM must be Numpy arrays of dtype float32") - if len(rpkm) < 20000: - warnings.warn( - f"WARNING: Creating DataLoader with only {len(rpkm)} sequences. " - "We normally expect 20,000 sequences or more to prevent overfitting. " - "As a deep learning model, VAEs are prone to overfitting with too few sequences. " - "You may want to lower the beta parameter, or use a different binner altogether." - ) - # Copy if not destroy - this way we can have all following operations in-place # for simplicity if not destroy: diff --git a/vamb/parsecontigs.py b/vamb/parsecontigs.py index 42e8b023..5cd3391b 100644 --- a/vamb/parsecontigs.py +++ b/vamb/parsecontigs.py @@ -5,12 +5,14 @@ ... tnfs, contignames, lengths = read_contigs(filehandle) """ +import sys import os as _os import numpy as _np import vamb.vambtools as _vambtools from collections.abc import Iterable, Sequence -from typing import IO, Union, TypeVar +from typing import IO, Union, TypeVar, Optional, IO from pathlib import Path +import warnings # This kernel is created in src/create_kernel.py. See that file for explanation _KERNEL: _np.ndarray = _vambtools.read_npz( @@ -158,7 +160,12 @@ def _convert(raw: _vambtools.PushArray, projected: _vambtools.PushArray): raw.clear() @classmethod - def from_file(cls: type[C], filehandle: Iterable[bytes], minlength: int = 100) -> C: + def from_file( + cls: type[C], + filehandle: Iterable[bytes], + minlength: int = 100, + logfile: Optional[IO[str]] = None, + ) -> C: """Parses a FASTA file open in binary reading mode, returning Composition. Input: @@ -206,4 +213,18 @@ def from_file(cls: type[C], filehandle: Iterable[bytes], minlength: int = 100) - _np.array(mask, dtype=bool), minlength, ) + + if len(metadata.lengths) < 20_000: + message = ( + f"WARNING: Parsed only {len(metadata.lengths)} sequences from FASTA file. 
" "We normally expect 20,000 sequences or more to prevent overfitting. " + "As a deep learning model, VAEs are prone to overfitting with too few sequences. " + "You may want to bin more samples as a time, lower the beta parameter, " + "or use a different binner altogether." + ) + warnings.warn(message) + if logfile is not None: + print("\n", file=logfile) + print(message, file=logfile) + return cls(metadata, tnfs_arr) diff --git a/vamb/vambtools.py b/vamb/vambtools.py index 3bb97403..9cefec49 100644 --- a/vamb/vambtools.py +++ b/vamb/vambtools.py @@ -13,6 +13,18 @@ from collections.abc import Iterable, Iterator, Generator from typing import Optional, IO, Union from pathlib import Path +import warnings + + +def showwarning_override(message, category, filename, lineno, file=None, line=None): + print(str(message) + "\n", file=file) + + +# It may seem horrifying to override a stdlib function, but this is the way recommended by the +# warnings documentation. +# We do it because it's the only way I know to prevent displaying file names, line numbers and +# source code to our users, which I think is a terrible user experience. +warnings.showwarning = showwarning_override class PushArray: