Warn users on too few contigs
Now warn users already when parsing contigs if there are fewer than 20,000
sequences, in order to prevent them from accidentally overfitting the VAE.
jakobnissen committed Nov 6, 2023
1 parent 9a8a66d commit 264b5a2
Showing 6 changed files with 66 additions and 21 deletions.
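
Sketch of the user-visible effect (the in-memory FASTA below is made up; Composition.from_file and the 20,000-sequence threshold are taken from the diff that follows):

import io
import warnings

from vamb.parsecontigs import Composition

# A FASTA with a single short contig - far below the new 20,000-sequence threshold.
fasta = io.BytesIO(b">contig1\n" + b"ACGT" * 100 + b"\n")

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    Composition.from_file(fasta, minlength=100)

assert any("20,000 sequences" in str(w.message) for w in caught)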
4 changes: 0 additions & 4 deletions test/test_encode.py
@@ -157,10 +157,6 @@ def test_samples_too_small(self):
t = self.tnfs.copy()
l = self.lens.copy()

with self.assertWarns(UserWarning):
dl = vamb.encode.make_dataloader(r, t, l, batchsize=256)
vae.trainmodel(dl, batchsteps=None, nepochs=2)

def test_loss_falls(self):
vae = vamb.encode.VAE(self.rpkm.shape[1])
rpkm_copy = self.rpkm.copy()
35 changes: 30 additions & 5 deletions test/test_parsecontigs.py
@@ -2,13 +2,15 @@
import unittest
import random
import numpy as np
import warnings

import testtools
from vamb.parsecontigs import Composition, CompositionMetaData


class TestReadContigs(unittest.TestCase):
records = []
large_io = io.BytesIO()
io = io.BytesIO()

@classmethod
@@ -21,8 +23,14 @@ def setUpClass(cls):
cls.io.write(i.format().encode())
cls.io.write(b"\n")

for i in range(25_000):
record = testtools.make_randseq(rng, 250, 300)
cls.large_io.write(record.format().encode())
cls.large_io.write(b"\n")

def setUp(self):
self.io.seek(0)
self.large_io.seek(0)

def test_unique_names(self):
with self.assertRaises(ValueError):
@@ -33,9 +41,22 @@ def test_unique_names(self):
1000,
)

# Does not warn
def test_nowarn(self):
with warnings.catch_warnings():
warnings.simplefilter("error", UserWarning)
Composition.from_file(self.large_io, minlength=250)

def test_warns_n_contigs(self):
with self.assertWarns(UserWarning):
Composition.from_file(self.io, minlength=250)

def test_filter_minlength(self):
minlen = 500
composition = Composition.from_file(self.io, minlength=450)

with self.assertWarns(UserWarning):
composition = Composition.from_file(self.io, minlength=450)

md = composition.metadata
hash1 = md.refhash

@@ -74,7 +95,8 @@ def test_minlength(self):
Composition.from_file(self.io, minlength=3)

def test_properties(self):
composition = Composition.from_file(self.io, minlength=420)
with self.assertWarns(UserWarning):
composition = Composition.from_file(self.io, minlength=420)
passed = list(filter(lambda x: len(x.sequence) >= 420, self.records))

self.assertEqual(composition.nseqs, len(composition.metadata.identifiers))
@@ -96,7 +118,8 @@ def test_properties(self):

def test_save_load(self):
buf = io.BytesIO()
composition_1 = Composition.from_file(self.io)
with self.assertWarns(UserWarning):
composition_1 = Composition.from_file(self.io)
md1 = composition_1.metadata
composition_1.save(buf)
buf.seek(0)
@@ -126,8 +149,10 @@ def test_windows_newlines(self):

buf1.seek(0)
buf2.seek(0)
comp1 = Composition.from_file(buf1)
comp2 = Composition.from_file(buf2)
with self.assertWarns(UserWarning):
comp1 = Composition.from_file(buf1)
with self.assertWarns(UserWarning):
comp2 = Composition.from_file(buf2)

self.assertEqual(comp1.metadata.refhash, comp2.metadata.refhash)
self.assertTrue(np.all(comp1.matrix == comp2.matrix))
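
unittest has assertWarns but no assertNotWarns, so test_nowarn falls back on the stdlib idiom of promoting UserWarning to an error inside a catch_warnings block; a standalone sketch of the pattern (parse_quietly is a hypothetical stand-in for the code under test):

import unittest
import warnings


def parse_quietly():
    return 42  # stands in for e.g. Composition.from_file on a large input


class ExampleTest(unittest.TestCase):
    def test_emits_no_userwarning(self):
        with warnings.catch_warnings():
            # Turn UserWarning into an exception: if the code under test warns,
            # the test fails loudly instead of passing silently.
            warnings.simplefilter("error", UserWarning)
            parse_quietly()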
2 changes: 1 addition & 1 deletion vamb/__main__.py
@@ -524,7 +524,7 @@ def calc_tnf(
log(f"Loading data from FASTA file {path}", logfile, 1)
with vamb.vambtools.Reader(str(path)) as file:
composition = vamb.parsecontigs.Composition.from_file(
file, minlength=options.min_contig_length
file, minlength=options.min_contig_length, logfile=logfile
)
composition.save(outdir.joinpath("composition.npz"))

9 changes: 0 additions & 9 deletions vamb/encode.py
@@ -9,7 +9,6 @@
from torch import Tensor
from torch import nn as _nn
from math import log as _log
import warnings

__doc__ = """Encode a depths matrix and a tnf matrix to latent representation.
@@ -87,14 +86,6 @@ def make_dataloader(
if not (rpkm.dtype == tnf.dtype == _np.float32):
raise ValueError("TNF and RPKM must be Numpy arrays of dtype float32")

if len(rpkm) < 20000:
warnings.warn(
f"WARNING: Creating DataLoader with only {len(rpkm)} sequences. "
"We normally expect 20,000 sequences or more to prevent overfitting. "
"As a deep learning model, VAEs are prone to overfitting with too few sequences. "
"You may want to lower the beta parameter, or use a different binner altogether."
)

# Copy if not destroy - this way we can have all following operations in-place
# for simplicity
if not destroy:
25 changes: 23 additions & 2 deletions vamb/parsecontigs.py
@@ -5,12 +5,14 @@
... tnfs, contignames, lengths = read_contigs(filehandle)
"""

import sys
import os as _os
import numpy as _np
import vamb.vambtools as _vambtools
from collections.abc import Iterable, Sequence
from typing import IO, Union, TypeVar
from typing import IO, Union, TypeVar, Optional
from pathlib import Path
import warnings

# This kernel is created in src/create_kernel.py. See that file for explanation
_KERNEL: _np.ndarray = _vambtools.read_npz(
@@ -158,7 +160,12 @@ def _convert(raw: _vambtools.PushArray, projected: _vambtools.PushArray):
raw.clear()

@classmethod
def from_file(cls: type[C], filehandle: Iterable[bytes], minlength: int = 100) -> C:
def from_file(
cls: type[C],
filehandle: Iterable[bytes],
minlength: int = 100,
logfile: Optional[IO[str]] = None,
) -> C:
"""Parses a FASTA file open in binary reading mode, returning Composition.
Input:
@@ -206,4 +213,18 @@ def from_file(cls: type[C], filehandle: Iterable[bytes], minlength: int = 100) -> C:
_np.array(mask, dtype=bool),
minlength,
)

if len(metadata.lengths) < 20_000:
message = (
f"WARNING: Parsed only {len(metadata.lengths)} sequences from FASTA file. "
"We normally expect 20,000 sequences or more to prevent overfitting. "
"As a deep learning model, VAEs are prone to overfitting with too few sequences. "
"You may want to bin more samples as a time, lower the beta parameter, "
"or use a different binner altogether."
)
warnings.warn(message)
if logfile is not None:
print("\n", file=logfile)
print(message, file=logfile)

return cls(metadata, tnfs_arr)
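
Call-site usage, as a sketch (both file paths and the minlength value are hypothetical; Reader and the new logfile parameter are used exactly as in __main__.py above):

import vamb

with vamb.vambtools.Reader("contigs.fna") as file, open("log.txt", "w") as logfile:
    # Any warning about too few contigs is now also echoed to the log file.
    composition = vamb.parsecontigs.Composition.from_file(
        file, minlength=2000, logfile=logfile
    )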
12 changes: 12 additions & 0 deletions vamb/vambtools.py
@@ -13,6 +13,18 @@
from collections.abc import Iterable, Iterator, Generator
from typing import Optional, IO, Union
from pathlib import Path
import warnings


def showwarning_override(message, category, filename, lineno, file=None, line=None):
print(str(message) + "\n", file=file)


# It may seem horrifying to override a stdlib function, but this is the way recommended
# by the warnings documentation.
# We do it because it's the only way I know to prevent displaying file names, line numbers
# and source code to our users, which I think makes for a terrible user experience
warnings.showwarning = showwarning_override
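
The effect of the assignment above, as a sketch (the warning text is made up, and the default rendering shown in the comments is approximate):

import warnings
import vamb.vambtools  # importing the module installs the override as a side effect

warnings.warn("Parsed only 5,000 sequences")
# With the override, only the message itself is printed (followed by a blank line):
#   Parsed only 5,000 sequences
# The stdlib default would instead have rendered something like:
#   /path/to/script.py:4: UserWarning: Parsed only 5,000 sequences
#     warnings.warn("Parsed only 5,000 sequences")

Note that when file is None, the override's print call writes to stdout, whereas the stdlib default writes to sys.stderr.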


class PushArray:
