Remove the dataloader mask #216

Merged (2 commits, Oct 6, 2023)

47 changes: 10 additions & 37 deletions test/test_encode.py
@@ -21,15 +21,11 @@ def not_nearly_same(self, A, B):
     def test_bad_args(self):
         # Bad rpkm
         with self.assertRaises(ValueError):
-            vamb.encode.make_dataloader(
-                [[1, 2, 3]], self.tnfs, np.ndarray([2000]), batchsize=32
-            )
+            vamb.encode.make_dataloader([[1, 2, 3]], self.tnfs, self.lens, batchsize=32)

         # bad tnfs
         with self.assertRaises(ValueError):
-            vamb.encode.make_dataloader(
-                self.rpkm, [[1, 2, 3]], np.ndarray([2000]), batchsize=32
-            )
+            vamb.encode.make_dataloader(self.rpkm, [[1, 2, 3]], self.lens, batchsize=32)

         # Bad batchsize
         with self.assertRaises(ValueError):
@@ -58,13 +54,11 @@ def test_destroy(self):
         copy_rpkm = self.rpkm.copy()
         copy_tnfs = self.tnfs.copy()

-        (dl, mask) = vamb.encode.make_dataloader(
-            self.rpkm, self.tnfs, self.lens, batchsize=32
-        )
+        _ = vamb.encode.make_dataloader(self.rpkm, self.tnfs, self.lens, batchsize=32)
         self.nearly_same(self.rpkm, copy_rpkm)
         self.nearly_same(self.tnfs, copy_tnfs)

-        (dl, mask) = vamb.encode.make_dataloader(
+        _ = vamb.encode.make_dataloader(
             copy_rpkm, copy_tnfs, self.lens, batchsize=32, destroy=True
         )
         self.not_nearly_same(self.rpkm, copy_rpkm)
@@ -74,7 +68,7 @@ def test_normalized(self):
         copy_rpkm = self.rpkm.copy()
         copy_tnfs = self.tnfs.copy()

-        (dl, mask) = vamb.encode.make_dataloader(
+        _ = vamb.encode.make_dataloader(
             copy_rpkm, copy_tnfs, self.lens, batchsize=32, destroy=True
         )

@@ -87,25 +81,10 @@
         self.nearly_same(np.sum(copy_rpkm, axis=1), np.ones(copy_rpkm.shape[0]))
         self.assertTrue(np.all(copy_rpkm >= 0.0))

-    def test_mask(self):
-        copy_rpkm = self.rpkm.copy()
-        copy_tnfs = self.tnfs.copy()
-        mask = np.ones(len(copy_rpkm)).astype(bool)
-
-        for bad_tnf in [0, 4, 9]:
-            copy_tnfs[bad_tnf, :] = 0
-            mask[bad_tnf] = False
-
-        (dl, mask2) = vamb.encode.make_dataloader(
-            copy_rpkm, copy_tnfs, self.lens, batchsize=32
-        )
-
-        self.assertTrue(np.all(mask == mask2))
-
     def test_single_sample(self):
         single_rpkm = self.rpkm[:, [0]]
         copy_single = single_rpkm.copy()
-        (dl, mask) = vamb.encode.make_dataloader(
+        dl = vamb.encode.make_dataloader(
             single_rpkm, self.tnfs.copy(), self.lens, batchsize=32, destroy=True
         )
         # When destroying a single sample, RPKM is set to 1.0
@@ -126,9 +105,7 @@ def test_single_sample(self):

     def test_iter(self):
         bs = 32
-        (dl, mask) = vamb.encode.make_dataloader(
-            self.rpkm, self.tnfs, self.lens, batchsize=bs
-        )
+        dl = vamb.encode.make_dataloader(self.rpkm, self.tnfs, self.lens, batchsize=bs)

         # Check right element type
         for M in next(iter(dl)):
@@ -140,9 +117,7 @@
         self.nearly_same(np.sum(rpkm.numpy(), axis=1), np.ones(bs))

     def test_randomized(self):
-        (dl, mask) = vamb.encode.make_dataloader(
-            self.rpkm, self.tnfs, self.lens, batchsize=64
-        )
+        dl = vamb.encode.make_dataloader(self.rpkm, self.tnfs, self.lens, batchsize=64)
         rpkm, tnfs, abundances, weights = next(iter(dl))

         # Test that first batch is not just the first 64 elements.
@@ -184,7 +159,7 @@ def test_loss_falls(self):
         vae = vamb.encode.VAE(self.rpkm.shape[1])
         rpkm_copy = self.rpkm.copy()
         tnfs_copy = self.tnfs.copy()
-        dl, mask = vamb.encode.make_dataloader(
+        dl = vamb.encode.make_dataloader(
             rpkm_copy, tnfs_copy, self.lens, batchsize=16, destroy=True
         )
         (di, ti, ai, we) = next(iter(dl))
@@ -213,9 +188,7 @@ def test_loss_falls(self):
     def test_encoding(self):
         nlatent = 15
         vae = vamb.encode.VAE(self.rpkm.shape[1], nlatent=nlatent)
-        dl, mask = vamb.encode.make_dataloader(
-            self.rpkm, self.tnfs, self.lens, batchsize=32
-        )
+        dl = vamb.encode.make_dataloader(self.rpkm, self.tnfs, self.lens, batchsize=32)
         encoding = vae.encode(dl)
         self.assertEqual(encoding.dtype, np.float32)
         self.assertEqual(encoding.shape, (len(self.rpkm), nlatent))
5 changes: 2 additions & 3 deletions test/test_results.py
@@ -91,14 +91,13 @@ def test_runs(self):
         )

         vae = vamb.encode.VAE(6)
-        dl, mask = vamb.encode.make_dataloader(
+        dl = vamb.encode.make_dataloader(
             self.rpkm.copy(), self.tnfs, self.lens, batchsize=16
         )
         vae.trainmodel(dl, nepochs=3, batchsteps=[1, 2])
         latent = vae.encode(dl)

         self.assertIsInstance(latent, np.ndarray)
-        self.assertIsInstance(mask, np.ndarray)

 if TEST_UNSTABLE_HASHES:

@@ -108,7 +107,7 @@ def test_result(self):
         np.random.seed(0)
         random.seed(0)
         vae = vamb.encode.VAE(6)
-        dl, mask = vamb.encode.make_dataloader(
+        dl = vamb.encode.make_dataloader(
             self.rpkm, self.tnfs, self.lens, batchsize=16
         )
         vae.trainmodel(dl, nepochs=3, batchsteps=[1, 2])
10 changes: 1 addition & 9 deletions vamb/__main__.py
@@ -765,22 +765,14 @@ def run(
         1,
     )

-    data_loader, mask = vamb.encode.make_dataloader(
+    data_loader = vamb.encode.make_dataloader(
         abundance.matrix,
         composition.matrix,
         composition.metadata.lengths,
         256,  # dummy value - we change this before using the actual loader
         destroy=True,
         cuda=vamb_options.cuda,
     )
-    composition.metadata.filter_mask(mask)
-
-    print("", file=logfile)
-    log("Created dataloader and mask", logfile, 0)
-    vamb.vambtools.write_npz(vamb_options.out_dir.joinpath("mask.npz"), mask)
-    n_discarded = len(mask) - mask.sum()
-    log(f"Number of sequences unsuitable for encoding: {n_discarded}", logfile, 1)
-    log(f"Number of sequences remaining: {len(mask) - n_discarded}", logfile, 1)

     latent = None
     if vae_options is not None:
39 changes: 21 additions & 18 deletions vamb/encode.py
@@ -17,7 +17,7 @@

 Usage:
 >>> vae = VAE(nsamples=6)
->>> dataloader, mask = make_dataloader(depths, tnf, lengths)
+>>> dataloader = make_dataloader(depths, tnf, lengths)
 >>> vae.trainmodel(dataloader)
 >>> latent = vae.encode(dataloader) # Encode to latent representation
 >>> latent.shape
@@ -56,24 +56,22 @@ def make_dataloader(
     batchsize: int = 256,
     destroy: bool = False,
     cuda: bool = False,
-) -> tuple[_DataLoader[tuple[Tensor, Tensor, Tensor]], _np.ndarray]:
-    """Create a DataLoader and a contig mask from RPKM and TNF.
+) -> _DataLoader:
+    """Create a DataLoader from RPKM, TNF and lengths.

     The dataloader is an object feeding minibatches of contigs to the VAE.
-    The data are normalized versions of the input datasets, with zero-contigs,
-    i.e. contigs where a row in either TNF or RPKM are all zeros, removed.
-    The mask is a boolean mask designating which contigs have been kept.
+    The data are normalized versions of the input datasets.

     Inputs:
         rpkm: RPKM matrix (N_contigs x N_samples)
         tnf: TNF matrix (N_contigs x N_TNF)
+        lengths: Numpy array of sequence length (N_contigs)
         batchsize: Starting size of minibatches for dataloader
         destroy: Mutate rpkm and tnf array in-place instead of making a copy.
         cuda: Pagelock memory of dataloader (use when using GPU acceleration)

     Outputs:
         DataLoader: An object feeding data to the VAE
-        mask: A boolean mask of which contigs are kept
     """

     if not isinstance(rpkm, _np.ndarray) or not isinstance(tnf, _np.ndarray):
@@ -88,7 +86,13 @@
     if not (rpkm.dtype == tnf.dtype == _np.float32):
         raise ValueError("TNF and RPKM must be Numpy arrays of dtype float32")

-    ### Copy arrays and mask them ###
+    if len(lengths) < batchsize:
+        raise ValueError(
+            "Fewer sequences left after filtering than the batch size. "
+            + "This probably means you try to run on a too small dataset (below ~5k sequences). "
+            + "Check the log file, and verify BAM file content is sensible."
+        )
+
     # Copy if not destroy - this way we can have all following operations in-place
     # for simplicity
     if not destroy:
@@ -103,16 +107,15 @@
     )
     rpkm *= 1_000_000 / sample_depths_sum

-    mask = tnf.sum(axis=1) != 0
-    if mask.sum() < batchsize:
+    zero_tnf = tnf.sum(axis=1) == 0
+    smallest_index = _np.argmax(zero_tnf)
+    if zero_tnf[smallest_index]:
         raise ValueError(
-            "Fewer sequences left after filtering than the batch size. "
-            + "This probably means you try to run on a too small dataset (below ~10k sequences), "
-            + "or that nearly all sequences were filtered away. Check the log file, "
-            + "and verify BAM file content is sensible."
+            f"TNF row at index {smallest_index} is all zeros. "
+            + "This implies that the sequence contained no 4-mers of A, C, G, T or U, "
+            + "making this sequence uninformative. This is probably a mistake. "
+            + "Verify that the sequence contains usable information (e.g. is not all N's)"
         )
-    _vambtools.numpy_inplace_maskarray(rpkm, mask)
-    _vambtools.numpy_inplace_maskarray(tnf, mask)

     total_abundance = rpkm.sum(axis=1)

@@ -131,7 +134,7 @@
     total_abundance.shape = (len(total_abundance), 1)

     # Create weights
-    lengths = (lengths[mask]).astype(_np.float32)
+    lengths = (lengths).astype(_np.float32)
     weights = _np.log(lengths).astype(_np.float32) - 5.0
     weights[weights < 2.0] = 2.0
     weights *= len(weights) / weights.sum()
@@ -155,7 +158,7 @@
         pin_memory=cuda,
     )

-    return dataloader, mask
+    return dataloader


 class VAE(_nn.Module):
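Note the behavioral change above: `make_dataloader` no longer drops all-zero TNF rows and returns a boolean mask; it now raises a `ValueError` when such a row is present. Callers that relied on the old masking must filter those rows themselves before building the dataloader. A minimal sketch of that migration, assuming `rpkm`, `tnf` and `lengths` are Numpy arrays as `make_dataloader` requires (the `make_filtered_dataloader` helper is hypothetical, not code from this PR):

import numpy as np
import vamb

def make_filtered_dataloader(rpkm, tnf, lengths, batchsize=256):
    """Hypothetical helper: drop contigs whose TNF row is all zeros,
    mimicking the mask that make_dataloader used to compute internally."""
    keep = tnf.sum(axis=1) != 0  # the same criterion the removed mask used
    dataloader = vamb.encode.make_dataloader(
        rpkm[keep], tnf[keep], lengths[keep], batchsize=batchsize
    )
    # Return the mask as well, so callers can subset contig metadata,
    # as vamb/__main__.py previously did with composition.metadata.filter_mask.
    return dataloader, keep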
1 change: 0 additions & 1 deletion workflow_avamb/README.md
@@ -129,7 +129,6 @@ Avamb produces the following output files:
 - `contignames`: text file containing a list of the contigs remaining after the minimum contig size allowed, and defined on the `min_contig_size` in the `config.json` file.
 - `lengths.npz`: Numpy object that contains the contig length, same order than the contignames.
 - `log.txt`: a text file with information about the Avamb run. Look here (and at stderr) if you experience errors.
-- `mask.npz`: considering the contigs abundances and tetra nucleotide frequencies computed per contig, some contigs might have been filtered out before binning, this numpy boolean object contains this masking.
 - `model.pt`: a file containing the trained VAE model. When running Avamb from a Python interpreter, the VAE can be loaded from this file to skip training.
 - `aae_model.pt`: a file containing the trained AAE model. When running Avamb from a Python interpreter, the AAE can be loaded from this file to skip training.
 - `vae_clusters.tsv`: file generated by clustering the VAE latent space, where each row is a sequence: Left column for the cluster (i.e bin) name, right column for the sequence name. You can create the FASTA-file bins themselves using the script in `src/create_fasta.py`
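For reference, this is how the new single-return API reads end to end. A minimal sketch on synthetic data: `nepochs` and `batchsteps` mirror the tests above, while the array shapes (including the 103 TNF features vamb computes per contig) are assumptions for illustration, not part of this diff:

import numpy as np
import vamb

rng = np.random.default_rng(0)
n_contigs, n_samples = 2048, 6
rpkm = rng.random((n_contigs, n_samples), dtype=np.float32)  # per-sample depths
tnf = rng.random((n_contigs, 103), dtype=np.float32)         # tetranucleotide features
lengths = rng.integers(2_000, 10_000, size=n_contigs)        # contig lengths

# make_dataloader now returns only the DataLoader, no mask.
dataloader = vamb.encode.make_dataloader(rpkm, tnf, lengths, batchsize=64)
vae = vamb.encode.VAE(nsamples=n_samples)
vae.trainmodel(dataloader, nepochs=3, batchsteps=[1, 2])
latent = vae.encode(dataloader)  # float32 array of shape (n_contigs, nlatent)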