Skip to content

Commit

Permalink
Warn users on too few contigs
Browse files Browse the repository at this point in the history
Now warn users already when parsing contigs if they have fewer than 20,000
sequences, in order to prevent users from accidentally overfitting the VAE.
  • Loading branch information
jakobnissen committed Nov 6, 2023
1 parent 9a8a66d commit c460bdb
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 12 deletions.
2 changes: 1 addition & 1 deletion vamb/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -524,7 +524,7 @@ def calc_tnf(
log(f"Loading data from FASTA file {path}", logfile, 1)
with vamb.vambtools.Reader(str(path)) as file:
composition = vamb.parsecontigs.Composition.from_file(
file, minlength=options.min_contig_length
file, minlength=options.min_contig_length, logfile=logfile
)
composition.save(outdir.joinpath("composition.npz"))

Expand Down
9 changes: 0 additions & 9 deletions vamb/encode.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
from torch import Tensor
from torch import nn as _nn
from math import log as _log
import warnings

__doc__ = """Encode a depths matrix and a tnf matrix to latent representation.
Expand Down Expand Up @@ -87,14 +86,6 @@ def make_dataloader(
if not (rpkm.dtype == tnf.dtype == _np.float32):
raise ValueError("TNF and RPKM must be Numpy arrays of dtype float32")

if len(rpkm) < 20000:
warnings.warn(
f"WARNING: Creating DataLoader with only {len(rpkm)} sequences. "
"We normally expect 20,000 sequences or more to prevent overfitting. "
"As a deep learning model, VAEs are prone to overfitting with too few sequences. "
"You may want to lower the beta parameter, or use a different binner altogether."
)

# Copy if not destroy - this way we can have all following operations in-place
# for simplicity
if not destroy:
Expand Down
24 changes: 22 additions & 2 deletions vamb/parsecontigs.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,12 @@
... tnfs, contignames, lengths = read_contigs(filehandle)
"""

import sys
import os as _os
import numpy as _np
import vamb.vambtools as _vambtools
from collections.abc import Iterable, Sequence
from typing import IO, Union, TypeVar
from typing import IO, Union, TypeVar, Optional, IO
from pathlib import Path

# This kernel is created in src/create_kernel.py. See that file for explanation
Expand Down Expand Up @@ -158,7 +159,12 @@ def _convert(raw: _vambtools.PushArray, projected: _vambtools.PushArray):
raw.clear()

@classmethod
def from_file(cls: type[C], filehandle: Iterable[bytes], minlength: int = 100) -> C:
def from_file(
cls: type[C],
filehandle: Iterable[bytes],
minlength: int = 100,
logfile: Optional[IO[str]] = None,
) -> C:
"""Parses a FASTA file open in binary reading mode, returning Composition.
Input:
Expand Down Expand Up @@ -206,4 +212,18 @@ def from_file(cls: type[C], filehandle: Iterable[bytes], minlength: int = 100) -
_np.array(mask, dtype=bool),
minlength,
)

if len(metadata.lengths) < 20_000:
message = (
f"WARNING: Parsed only {len(metadata.lengths)} sequences from FASTA file. "
"We normally expect 20,000 sequences or more to prevent overfitting. "
"As a deep learning model, VAEs are prone to overfitting with too few sequences. "
"You may want to bin more samples as a time, lower the beta parameter, "
"or use a different binner altogether."
)
print(message, file=sys.stderr)
if logfile is not None:
print("\n", file=logfile)
print(message, file=logfile)

return cls(metadata, tnfs_arr)

0 comments on commit c460bdb

Please sign in to comment.