Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Do not error with too small batch size #231

Merged
merged 1 commit into from
Oct 24, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -31,3 +31,10 @@ build-backend = "setuptools.build_meta"

[tool.ruff]
ignore = ["E501"]

# pyproject.toml
# pytest: treat every warning as an error during tests, except UserWarning, which is ignored.
[tool.pytest.ini_options]
filterwarnings = [
"error",
"ignore::UserWarning",
]
14 changes: 10 additions & 4 deletions test/test_encode.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,10 +46,6 @@ def test_bad_args(self):
self.rpkm.astype(np.float64), self.tnfs, self.lens, batchsize=32
)

# Batchsize is too large
with self.assertRaises(ValueError):
vamb.encode.make_dataloader(self.rpkm, self.tnfs, self.lens, batchsize=256)

def test_destroy(self):
copy_rpkm = self.rpkm.copy()
copy_tnfs = self.tnfs.copy()
Expand Down Expand Up @@ -155,6 +151,16 @@ def test_bad_args(self):
with self.assertRaises(ValueError):
vamb.encode.VAE(5, dropout=-0.001)

def test_samples_too_small(self):
    """Check that a UserWarning is emitted instead of an exception.

    Builds a dataloader with ``batchsize=64`` from copies of the fixture
    arrays and trains a fresh VAE on it for two epochs, all inside an
    ``assertWarns(UserWarning)`` context.
    """
    model = vamb.encode.VAE(self.rpkm.shape[1])

    # Work on copies so the shared fixture arrays are left untouched.
    depths = self.rpkm.copy()
    composition = self.tnfs.copy()
    lengths = self.lens.copy()

    with self.assertWarns(UserWarning):
        loader = vamb.encode.make_dataloader(
            depths, composition, lengths, batchsize=64
        )
        model.trainmodel(loader, batchsteps=None, nepochs=2)

def test_loss_falls(self):
vae = vamb.encode.VAE(self.rpkm.shape[1])
rpkm_copy = self.rpkm.copy()
Expand Down
23 changes: 7 additions & 16 deletions vamb/encode.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from torch import Tensor
from torch import nn as _nn
from math import log as _log
import warnings

__doc__ = """Encode a depths matrix and a tnf matrix to latent representation.

Expand Down Expand Up @@ -86,11 +87,12 @@ def make_dataloader(
if not (rpkm.dtype == tnf.dtype == _np.float32):
raise ValueError("TNF and RPKM must be Numpy arrays of dtype float32")

if len(lengths) < batchsize:
raise ValueError(
"Fewer sequences left after filtering than the batch size. "
+ "This probably means you try to run on a too small dataset (below ~5k sequences), "
+ "Check the log file, and verify BAM file content is sensible."
if len(rpkm) < 20000:
warnings.warn(
f"WARNING: Creating DataLoader with only {len(rpkm)} sequences. "
"We normally expect 20,000 sequences or more to prevent overfitting. "
"As a deep learning model, VAEs are prone to overfitting with too few sequences. "
"You may want to lower the beta parameter, or use a different binner altogether."
)

# Copy if not destroy - this way we can have all following operations in-place
Expand Down Expand Up @@ -579,17 +581,6 @@ def trainmodel(
raise ValueError("All elements of batchsteps must be integers")
if max(batchsteps, default=0) >= nepochs:
raise ValueError("Max batchsteps must not equal or exceed nepochs")
last_batchsize = dataloader.batch_size * 2 ** len(batchsteps)
if len(dataloader.dataset) < last_batchsize: # type: ignore
raise ValueError(
f"Last batch size of {last_batchsize} exceeds dataset length "
f"of {len(dataloader.dataset)}. " # type: ignore
"This means you have too few contigs left after filtering to train. "
"It is not adviced to run Vamb with fewer than 10,000 sequences "
"after filtering. "
"Please check the Vamb log file to see where the sequences were "
"filtered away, and verify BAM files has sensible content."
)
batchsteps_set = set(batchsteps)

# Get number of features
Expand Down