From 4eda306ed59e5b9c14d65a6b4e7f90ba4956621b Mon Sep 17 00:00:00 2001 From: svekut Date: Fri, 13 Oct 2023 15:55:04 +0200 Subject: [PATCH] feat: add absolute abundance --- vamb/h_loss.py | 165 ++++++++--------- vamb/semisupervised_encode.py | 168 +++++++++++------- .../src/shortread_CAMI2_no_predictor.sh | 22 +-- .../src/shortread_CAMI2_predictor.sh | 2 +- .../src/shortread_marine_predictor.sh | 2 +- .../src/shortread_rhizo_predictor.sh | 2 +- 6 files changed, 194 insertions(+), 167 deletions(-) diff --git a/vamb/h_loss.py b/vamb/h_loss.py index 1a1f6e70..b368d525 100644 --- a/vamb/h_loss.py +++ b/vamb/h_loss.py @@ -71,28 +71,33 @@ def collate_fn_concat_hloss(num_categories, table_parent, batch): a = _torch.stack([i[0] for i in batch]) b = _torch.stack([i[1] for i in batch]) c = _torch.stack([i[2] for i in batch]) - d = [(i[3],) for i in batch] - return a, b, c, collate_fn_labels_hloss(num_categories, table_parent, d)[0] + d = _torch.stack([i[3] for i in batch]) + e = [i[4] for i in batch] + return a, b, c, d, collate_fn_labels_hloss(num_categories, table_parent, e)[0] def collate_fn_semisupervised_hloss(num_categories, table_parent, batch): a = _torch.stack([i[0] for i in batch]) b = _torch.stack([i[1] for i in batch]) c = _torch.stack([i[2] for i in batch]) - d = [i[3] for i in batch] - e = _torch.stack([i[4] for i in batch]) + d = _torch.stack([i[3] for i in batch]) + e = [i[4] for i in batch] f = _torch.stack([i[5] for i in batch]) g = _torch.stack([i[6] for i in batch]) - h = [i[7] for i in batch] + h = _torch.stack([i[7] for i in batch]) + m = _torch.stack([i[8] for i in batch]) + n = [i[9] for i in batch] return ( a, b, c, - collate_fn_labels_hloss(num_categories, table_parent, d)[0], - e, + d, + collate_fn_labels_hloss(num_categories, table_parent, e)[0], f, g, - collate_fn_labels_hloss(num_categories, table_parent, h)[0], + h, + m, + collate_fn_labels_hloss(num_categories, table_parent, n)[0], ) @@ -107,7 +112,7 @@ def make_dataloader_labels_hloss( destroy=False, cuda=False, ): - _, _, _, batchsize, n_workers, cuda, mask = _semisupervised_encode._make_dataset( + _, _, _, _, batchsize, n_workers, cuda, mask = _semisupervised_encode._make_dataset( rpkm, tnf, lengths, batchsize=batchsize, destroy=destroy, cuda=cuda ) dataset = _TensorDataset(_torch.Tensor(labels).long()) @@ -139,6 +144,7 @@ def make_dataloader_concat_hloss( ( depthstensor, tnftensor, + total_abundance_tensor, weightstensor, batchsize, n_workers, @@ -149,7 +155,7 @@ def make_dataloader_concat_hloss( ) labelstensor = _torch.Tensor(labels).long() dataset = _TensorDataset( - depthstensor, tnftensor, weightstensor, labelstensor + depthstensor, tnftensor, total_abundance_tensor, weightstensor, labelstensor ) dataloader = _DataLoader( dataset=dataset, @@ -194,11 +200,13 @@ def make_dataloader_semisupervised_hloss( dataloader_vamb.dataset.tensors[0][indices_unsup_vamb], dataloader_vamb.dataset.tensors[1][indices_unsup_vamb], dataloader_vamb.dataset.tensors[2][indices_unsup_vamb], + dataloader_vamb.dataset.tensors[3][indices_unsup_vamb], dataloader_labels.dataset.tensors[0][indices_unsup_labels], dataloader_joint.dataset.tensors[0][indices_sup], dataloader_joint.dataset.tensors[1][indices_sup], dataloader_joint.dataset.tensors[2][indices_sup], dataloader_joint.dataset.tensors[3][indices_sup], + dataloader_joint.dataset.tensors[4][indices_sup], ) dataloader_all = _DataLoader( dataset=dataset_all, @@ -344,18 +352,6 @@ def calc_loss(self, labels_in, labels_out, mu, logsigma): loss = ce_labels * ce_labels_weight + kld * 
kld_weight return loss, ce_labels, kld, _torch.tensor(0) - # _, labels_in_indices = labels_in.max(dim=1) - # with _torch.no_grad(): - # gt_node = self.eval_label_map.to_node[labels_in_indices.cpu()] - # prob = self.pred_fn(labels_out) - # prob = prob.cpu().numpy() - # pred = _infer.argmax_with_confidence( - # self.specificity, prob, 0.5, self.not_trivial - # ) - # pred = _hier.truncate_given_lca(gt_node, pred, self.find_lca(gt_node, pred)) - - # return loss, ce_labels, kld, _torch.tensor(_np.sum(gt_node == pred)) - def trainmodel( self, dataloader: _DataLoader[tuple[Tensor]], @@ -520,47 +516,44 @@ def calc_loss( depths_out, tnf_in, tnf_out, + abundance_in, + abundance_out, labels_in, labels_out, mu, logsigma, weights, ): - # If multiple samples, use cross entropy, else use SSE for abundance - if self.nsamples > 1: - # Add 1e-9 to depths_out to avoid numerical instability. - ce = -((depths_out + 1e-9).log() * depths_in).sum(dim=1).mean() - ce_weight = (1 - self.alpha) / _log(self.nsamples) - else: - ce = (depths_out - depths_in).pow(2).sum(dim=1).mean() - ce_weight = 1 - self.alpha + ab_sse = (abundance_out - abundance_in).pow(2).sum(dim=1) + # Add 1e-9 to depths_out to avoid numerical instability. + ce = -((depths_out + 1e-9).log() * depths_in).sum(dim=1) + sse = (tnf_out - tnf_in).pow(2).sum(dim=1) + kld = 0.5 * (mu.pow(2)).sum(dim=1) - ce_labels = self.loss_fn(labels_out, labels_in) + # Avoid having the denominator be zero + if self.nsamples == 1: + ce_weight = 0.0 + else: + ce_weight = ((1 - self.alpha) * (self.nsamples - 1)) / ( + self.nsamples * _log(self.nsamples) + ) - ce_labels_weight = 1.0 - sse = (tnf_out - tnf_in).pow(2).sum(dim=1).mean() - kld = -0.5 * (1 + logsigma - mu.pow(2) - logsigma.exp()).sum(dim=1).mean() + ab_sse_weight = (1 - self.alpha) * (1 / self.nsamples) sse_weight = self.alpha / self.ntnf kld_weight = 1 / (self.nlatent * self.beta) + + ce_labels = self.loss_fn(labels_out, labels_in) + ce_labels_weight = 1.0 + loss = ( - ce * ce_weight + ab_sse*ab_sse_weight + + ce * ce_weight + sse * sse_weight + ce_labels * ce_labels_weight + kld * kld_weight ) * weights return loss, ce, sse, ce_labels, kld, _torch.tensor(0) - # _, labels_in_indices = labels_in.max(dim=1) - # with _torch.no_grad(): - # gt_node = self.eval_label_map.to_node[labels_in_indices.cpu()] - # prob = self.pred_fn(labels_out) - # prob = prob.cpu().numpy() - # pred = _infer.argmax_with_confidence( - # self.specificity, prob, 0.5, self.not_trivial - # ) - # pred = _hier.truncate_given_lca(gt_node, pred, self.find_lca(gt_node, pred)) - # return loss, ce, sse, ce_labels, kld, _torch.tensor(_np.sum(gt_node == pred)) - def kld_gauss(p_mu, p_logstd, q_mu, q_logstd): loss = ( @@ -710,6 +703,8 @@ def calc_loss_joint( depths_out, tnf_in, tnf_out, + abundance_in, + abundance_out, labels_in, labels_out, mu_sup, @@ -720,48 +715,43 @@ def calc_loss_joint( logsigma_labels_unsup, weights, ): - # If multiple samples, use cross entropy, else use SSE for abundance - if self.VAEVamb.nsamples > 1: - # Add 1e-9 to depths_out to avoid numerical instability. - ce = -((depths_out + 1e-9).log() * depths_in).sum(dim=1).mean() - ce_weight = (1 - self.VAEVamb.alpha) / _log(self.VAEVamb.nsamples) + ab_sse = (abundance_out - abundance_in).pow(2).sum(dim=1) + # Add 1e-9 to depths_out to avoid numerical instability. 
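+        # ab_sse above is an SSE over the single total-abundance column, while ce below is a
+        # cross entropy over the per-sample depth distribution (depths_out is softmax-normalised
+        # when nsamples > 1).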
+ ce = -((depths_out + 1e-9).log() * depths_in).sum(dim=1) + sse = (tnf_out - tnf_in).pow(2).sum(dim=1) + + # Avoid having the denominator be zero + if self.VAEVamb.nsamples == 1: + ce_weight = 0.0 else: - ce = (depths_out - depths_in).pow(2).sum(dim=1).mean() - ce_weight = 1 - self.VAEVamb.alpha + ce_weight = ((1 - self.alpha) * (self.nsamples - 1)) / ( + self.nsamples * _log(self.nsamples) + ) + + ab_sse_weight = (1 - self.alpha) * (1 / self.nsamples) + sse_weight = self.alpha / self.ntnf - # _, labels_in_indices = labels_in.max(dim=1) - # ce_labels = _nn.CrossEntropyLoss()(labels_out, labels_in_indices) ce_labels = self.VAEJoint.loss_fn(labels_out, labels_in) ce_labels_weight = 1.0 # TODO: figure out - sse = (tnf_out - tnf_in).pow(2).sum(dim=1) - sse_weight = self.VAEVamb.alpha / self.VAEVamb.ntnf - kld_weight = 1 / (self.VAEVamb.nlatent * self.VAEVamb.beta) - reconstruction_loss = ( - ce * ce_weight + sse * sse_weight + ce_labels * ce_labels_weight - ) + weighed_ab = ab_sse * ab_sse_weight + weighed_ce = ce * ce_weight + weighed_sse = sse * sse_weight + weighed_labels = ce_labels * ce_labels_weight + + reconstruction_loss = weighed_ce + weighed_ab + weighed_sse + weighed_labels kld_vamb = kld_gauss(mu_sup, logsigma_sup, mu_vamb_unsup, logsigma_vamb_unsup) kld_labels = kld_gauss( mu_sup, logsigma_sup, mu_labels_unsup, logsigma_labels_unsup ) kld = kld_vamb + kld_labels + kld_weight = 1 / (self.nlatent * self.beta) kld_loss = kld * kld_weight loss = (reconstruction_loss + kld_loss) * weights return loss.mean(), ce.mean(), sse.mean(), ce_labels.mean(), kld_vamb.mean(), kld_labels.mean(), _torch.tensor(0) - # _, labels_in_indices = labels_in.max(dim=1) - # with _torch.no_grad(): - # gt_node = self.VAEJoint.eval_label_map.to_node[labels_in_indices.cpu()] - # prob = self.VAEJoint.pred_fn(labels_out) - # prob = prob.cpu().numpy() - # pred = _infer.argmax_with_confidence( - # self.VAEJoint.specificity, prob, 0.5, self.VAEJoint.not_trivial - # ) - # pred = _hier.truncate_given_lca(gt_node, pred, self.VAEJoint.find_lca(gt_node, pred)) - # return loss.mean(), ce.mean(), sse.mean(), ce_labels.mean(), kld_vamb.mean(), kld_labels.mean(), _torch.tensor(_np.sum(gt_node == pred)) - class VAMB2Label(_nn.Module): """Label predictor, subclass of torch.nn.Module. 
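Taken together, the calc_loss and calc_loss_joint rewrites above decompose the reconstruction term into an SSE on the new total-abundance column, a cross entropy on the per-sample depths, an SSE on TNF, and a KL term, each with its own weight. A minimal standalone sketch of that weighting (the function name, signature and example values are illustrative only, not part of vamb's API):

from math import log

def loss_weights(alpha, nsamples, ntnf, nlatent, beta):
    # The per-sample depth cross entropy drops out when there is only one sample,
    # leaving the total-abundance SSE to carry the abundance signal.
    if nsamples == 1:
        ce_weight = 0.0
    else:
        ce_weight = ((1 - alpha) * (nsamples - 1)) / (nsamples * log(nsamples))
    ab_sse_weight = (1 - alpha) * (1 / nsamples)  # total-abundance SSE
    sse_weight = alpha / ntnf                     # TNF SSE
    kld_weight = 1 / (nlatent * beta)             # KL divergence
    return ce_weight, ab_sse_weight, sse_weight, kld_weight

print(loss_weights(alpha=0.15, nsamples=4, ntnf=103, nlatent=32, beta=200.0))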
@@ -847,7 +837,7 @@ def __init__( # Add all other hidden layers for nin, nout in zip( - [self.nsamples + self.ntnf] + self.nhiddens, self.nhiddens + [self.nsamples + self.ntnf + 1] + self.nhiddens, self.nhiddens ): self.encoderlayers.append(_nn.Linear(nin, nout)) self.encodernorms.append(_nn.BatchNorm1d(nout)) @@ -899,8 +889,8 @@ def _predict(self, tensor: Tensor) -> tuple[Tensor, Tensor]: labels_out = reconstruction.narrow(1, 0, self.nlabels) return labels_out - def forward(self, depths, tnf, weights): - tensor = _torch.cat((depths, tnf), 1) + def forward(self, depths, tnf, abundances, weights): + tensor = _torch.cat((depths, tnf, abundances), 1) labels_out = self._predict(tensor) return labels_out @@ -908,18 +898,6 @@ def calc_loss(self, labels_in, labels_out): ce_labels = self.loss_fn(labels_out, labels_in) return ce_labels, _torch.tensor(0) - # _, labels_in_indices = labels_in.max(dim=1) - # with _torch.no_grad(): - # gt_node = self.eval_label_map.to_node[labels_in_indices.cpu()] - # prob = self.pred_fn(labels_out) - # prob = prob.cpu().numpy() - # pred = _infer.argmax_with_confidence( - # self.specificity, prob, 0.5, self.not_trivial - # ) - # pred = _hier.truncate_given_lca(gt_node, pred, self.find_lca(gt_node, pred)) - - # return ce_labels, _torch.tensor(_np.sum(gt_node == pred)) - def predict(self, data_loader) -> _np.ndarray: self.eval() @@ -927,7 +905,7 @@ def predict(self, data_loader) -> _np.ndarray: data_loader, data_loader.batch_size, encode=True ) - depths_array, _, _ = data_loader.dataset.tensors + depths_array, _, _, _ = data_loader.dataset.tensors length = len(depths_array) # We make a Numpy array instead of a Torch array because, if we create @@ -938,15 +916,16 @@ def predict(self, data_loader) -> _np.ndarray: row = 0 with _torch.no_grad(): - for depths, tnf, weights in new_data_loader: + for depths, tnf, abundances, weights in new_data_loader: # Move input to GPU if requested if self.usecuda: depths = depths.cuda() tnf = tnf.cuda() + abundances = abundances.cuda() weights = weights.cuda() # Evaluate - labels = self(depths, tnf, weights) + labels = self(depths, tnf, abundances, weights) with _torch.no_grad(): prob = self.pred_fn(labels) if self.usecuda: @@ -1003,20 +982,22 @@ def trainepoch(self, data_loader, epoch, optimizer, batchsteps, logfile): collate_fn=data_loader.collate_fn, ) - for depths_in, tnf_in, weights, labels_in in data_loader: + for depths_in, tnf_in, abundances_in, weights, labels_in in data_loader: depths_in.requires_grad = True tnf_in.requires_grad = True + abundances_in.requires_grad = True labels_in.requires_grad = True if self.usecuda: depths_in = depths_in.cuda() tnf_in = tnf_in.cuda() + abundances_in = abundances_in.cuda() weights = weights.cuda() labels_in = labels_in.cuda() optimizer.zero_grad() - labels_out = self(depths_in, tnf_in, weights) + labels_out = self(depths_in, tnf_in, abundances_in, weights) loss, correct_labels = self.calc_loss(labels_in, labels_out) diff --git a/vamb/semisupervised_encode.py b/vamb/semisupervised_encode.py index ee44ba8d..54f91012 100644 --- a/vamb/semisupervised_encode.py +++ b/vamb/semisupervised_encode.py @@ -32,12 +32,14 @@ def collate_fn_concat(num_categories, batch): a = _torch.stack([i[0] for i in batch]) b = _torch.stack([i[1] for i in batch]) c = _torch.stack([i[2] for i in batch]) - d = [i[3] for i in batch] + d = _torch.stack([i[3] for i in batch]) + e = [i[4] for i in batch] return ( a, b, c, - F.one_hot(_torch.as_tensor(d), num_classes=max(num_categories, 105)) + d, + F.one_hot(_torch.as_tensor(e), 
num_classes=max(num_categories, 105)) .squeeze(1) .float(), ) @@ -47,22 +49,26 @@ def collate_fn_semisupervised(num_categories, batch): a = _torch.stack([i[0] for i in batch]) b = _torch.stack([i[1] for i in batch]) c = _torch.stack([i[2] for i in batch]) - d = [i[3] for i in batch] - e = _torch.stack([i[4] for i in batch]) + d = _torch.stack([i[3] for i in batch]) + e = [i[4] for i in batch] f = _torch.stack([i[5] for i in batch]) g = _torch.stack([i[6] for i in batch]) - h = [i[7] for i in batch] + h = _torch.stack([i[7] for i in batch]) + m = _torch.stack([i[8] for i in batch]) + n = [i[9] for i in batch] return ( a, b, c, - F.one_hot(_torch.as_tensor(d), num_classes=max(num_categories, 105)) + d, + F.one_hot(_torch.as_tensor(e), num_classes=max(num_categories, 105)) .squeeze(1) .float(), - e, f, g, - F.one_hot(_torch.as_tensor(h), num_classes=max(num_categories, 105)) + h, + m, + F.one_hot(_torch.as_tensor(n), num_classes=max(num_categories, 105)) .squeeze(1) .float(), ) @@ -83,8 +89,8 @@ def _make_dataset(rpkm, tnf, lengths, batchsize=256, destroy=False, cuda=False): dataloader, mask = _encode.make_dataloader( rpkm, tnf, lengths, batchsize, destroy, cuda ) - depthstensor, tnftensor, weightstensor = dataloader.dataset.tensors - return depthstensor, tnftensor, weightstensor, batchsize, n_workers, cuda, mask + depthstensor, tnftensor, total_abundance_tensor, weightstensor = dataloader.dataset.tensors + return depthstensor, tnftensor, total_abundance_tensor, weightstensor, batchsize, n_workers, cuda, mask def make_dataloader_concat( @@ -93,6 +99,7 @@ def make_dataloader_concat( ( depthstensor, tnftensor, + total_abundance_tensor, weightstensor, batchsize, n_workers, @@ -103,7 +110,7 @@ def make_dataloader_concat( ) labels_int = _np.unique(labels, return_inverse=True)[1][mask] dataset = _TensorDataset( - depthstensor, tnftensor, weightstensor, _torch.from_numpy(labels_int) + depthstensor, tnftensor, total_abundance_tensor, weightstensor, _torch.from_numpy(labels_int) ) dataloader = _DataLoader( dataset=dataset, @@ -120,7 +127,7 @@ def make_dataloader_concat( def make_dataloader_labels( rpkm, tnf, lengths, labels, batchsize=256, destroy=False, cuda=False ): - _, _, _, batchsize, n_workers, cuda, mask = _make_dataset( + _, _, _, _, batchsize, n_workers, cuda, mask = _make_dataset( rpkm, tnf, lengths, batchsize=batchsize, destroy=destroy, cuda=cuda ) labels_int = _np.unique(labels, return_inverse=True)[1][mask] @@ -168,11 +175,13 @@ def make_dataloader_semisupervised( dataloader_vamb.dataset.tensors[0][indices_unsup_vamb], dataloader_vamb.dataset.tensors[1][indices_unsup_vamb], dataloader_vamb.dataset.tensors[2][indices_unsup_vamb], + dataloader_vamb.dataset.tensors[3][indices_unsup_vamb], dataloader_labels.dataset.tensors[0][indices_unsup_labels], dataloader_joint.dataset.tensors[0][indices_sup], dataloader_joint.dataset.tensors[1][indices_sup], dataloader_joint.dataset.tensors[2][indices_sup], dataloader_joint.dataset.tensors[3][indices_sup], + dataloader_joint.dataset.tensors[4][indices_sup], ) dataloader_all = _DataLoader( dataset=dataset_all, @@ -505,28 +514,29 @@ def _decode(self, tensor): reconstruction = self.outputlayer(tensor) - # Decompose reconstruction to depths and tnf signal + # Decompose reconstruction to depths, tnf and abundance signal depths_out = reconstruction.narrow(1, 0, self.nsamples) tnf_out = reconstruction.narrow(1, self.nsamples, self.ntnf) - labels_out = reconstruction.narrow(1, self.nsamples + self.ntnf, self.nlabels) + abundance_out = reconstruction.narrow(1, 
self.nsamples + self.ntnf, 1) + labels_out = reconstruction.narrow(1, self.nsamples + self.ntnf + 1, self.nlabels) # If multiple samples, apply softmax if self.nsamples > 1: depths_out = _softmax(depths_out, dim=1) - labels_out = _softmax(labels_out, dim=1) + # labels_out = _softmax(labels_out, dim=1) - return depths_out, tnf_out, labels_out + return depths_out, tnf_out, abundance_out, labels_out - def forward(self, depths, tnf, labels): - tensor = _torch.cat((depths, tnf, labels), 1) + def forward(self, depths, tnf, abundance, labels): + tensor = _torch.cat((depths, tnf, abundance, labels), 1) mu = self._encode(tensor) logsigma = _torch.zeros(mu.size()) if self.usecuda: logsigma = logsigma.cuda() latent = self.reparameterize(mu) - depths_out, tnf_out, labels_out = self._decode(latent) + depths_out, tnf_out, abundance_out, labels_out = self._decode(latent) - return depths_out, tnf_out, labels_out, mu, logsigma + return depths_out, tnf_out, abundance_out, labels_out, mu, logsigma def calc_loss( self, @@ -534,34 +544,44 @@ def calc_loss( depths_out, tnf_in, tnf_out, + abundance_in, + abundance_out, labels_in, labels_out, mu, logsigma, weights, ): - # If multiple samples, use cross entropy, else use SSE for abundance - if self.nsamples > 1: - # Add 1e-9 to depths_out to avoid numerical instability. - ce = -((depths_out + 1e-9).log() * depths_in).sum(dim=1).mean() - ce_weight = (1 - self.alpha) / _log(self.nsamples) + ab_sse = (abundance_out - abundance_in).pow(2).sum(dim=1) + # Add 1e-9 to depths_out to avoid numerical instability. + ce = -((depths_out + 1e-9).log() * depths_in).sum(dim=1) + sse = (tnf_out - tnf_in).pow(2).sum(dim=1) + kld = 0.5 * (mu.pow(2)).sum(dim=1) + + # Avoid having the denominator be zero + if self.nsamples == 1: + ce_weight = 0.0 else: - ce = (depths_out - depths_in).pow(2).sum(dim=1).mean() - ce_weight = 1 - self.alpha + ce_weight = ((1 - self.alpha) * (self.nsamples - 1)) / ( + self.nsamples * _log(self.nsamples) + ) + + ab_sse_weight = (1 - self.alpha) * (1 / self.nsamples) + sse_weight = self.alpha / self.ntnf + kld_weight = 1 / (self.nlatent * self.beta) _, labels_in_indices = labels_in.max(dim=1) ce_labels = _nn.CrossEntropyLoss()(labels_out, labels_in_indices) ce_labels_weight = 1.0 # TODO: figure out - sse = (tnf_out - tnf_in).pow(2).sum(dim=1).mean() - kld = -0.5 * (1 + logsigma - mu.pow(2) - logsigma.exp()).sum(dim=1).mean() - sse_weight = self.alpha / self.ntnf - kld_weight = 1 / (self.nlatent * self.beta) - loss = ( - ce * ce_weight - + sse * sse_weight - + ce_labels * ce_labels_weight - + kld * kld_weight - ) * weights + + weighed_ab = ab_sse * ab_sse_weight + weighed_ce = ce * ce_weight + weighed_sse = sse * sse_weight + weighed_kld = kld * kld_weight + weighed_labels = ce_labels * ce_labels_weight + + reconstruction_loss = weighed_ce + weighed_ab + weighed_sse + weighed_labels + loss = (reconstruction_loss + weighed_kld) * weights _, labels_out_indices = labels_out.max(dim=1) _, labels_in_indices = labels_in.max(dim=1) @@ -595,21 +615,23 @@ def trainepoch(self, data_loader, epoch, optimizer, batchsteps, logfile): collate_fn=data_loader.collate_fn, ) - for depths_in, tnf_in, weights, labels_in in data_loader: + for depths_in, tnf_in, abundance_in, weights, labels_in in data_loader: depths_in.requires_grad = True tnf_in.requires_grad = True + abundance_in.requires_grad = True labels_in.requires_grad = True if self.usecuda: depths_in = depths_in.cuda() tnf_in = tnf_in.cuda() + abundance_in = abundance_in.cuda() weights = weights.cuda() labels_in = 
labels_in.cuda() optimizer.zero_grad() - depths_out, tnf_out, labels_out, mu, logsigma = self( - depths_in, tnf_in, labels_in + depths_out, tnf_out, abundance_out, labels_out, mu, logsigma = self( + depths_in, tnf_in, abundance_in, labels_in ) loss, ce, sse, ce_labels, kld, correct_labels = self.calc_loss( @@ -617,6 +639,8 @@ def trainepoch(self, data_loader, epoch, optimizer, batchsteps, logfile): depths_out, tnf_in, tnf_out, + abundance_in, + abundance_out, labels_in, labels_out, mu, @@ -681,15 +705,16 @@ def encode(self, data_loader): row = 0 with _torch.no_grad(): - for depths, tnf, weights, labels in new_data_loader: + for depths, tnf, ab, weights, labels in new_data_loader: # Move input to GPU if requested if self.usecuda: depths = depths.cuda() tnf = tnf.cuda() + ab = ab.cuda() labels = labels.cuda() # Evaluate - _, _, _, mu, _ = self(depths, tnf, labels) + _, _, _, _, mu, _ = self(depths, tnf, ab, labels) if self.usecuda: mu = mu.cpu() @@ -769,6 +794,8 @@ def calc_loss_joint( depths_out, tnf_in, tnf_out, + abundance_in, + abundance_out, labels_in, labels_out, mu_sup, @@ -779,31 +806,39 @@ def calc_loss_joint( logsigma_labels_unsup, weights, ): - # If multiple samples, use cross entropy, else use SSE for abundance - if self.VAEVamb.nsamples > 1: - # Add 1e-9 to depths_out to avoid numerical instability. - ce = -((depths_out + 1e-9).log() * depths_in).sum(dim=1).mean() - ce_weight = (1 - self.VAEVamb.alpha) / _log(self.VAEVamb.nsamples) + ab_sse = (abundance_out - abundance_in).pow(2).sum(dim=1) + # Add 1e-9 to depths_out to avoid numerical instability. + ce = -((depths_out + 1e-9).log() * depths_in).sum(dim=1) + sse = (tnf_out - tnf_in).pow(2).sum(dim=1) + + # Avoid having the denominator be zero + if self.VAEVamb.nsamples == 1: + ce_weight = 0.0 else: - ce = (depths_out - depths_in).pow(2).sum(dim=1).mean() - ce_weight = 1 - self.VAEVamb.alpha + ce_weight = ((1 - self.alpha) * (self.nsamples - 1)) / ( + self.nsamples * _log(self.nsamples) + ) + + ab_sse_weight = (1 - self.alpha) * (1 / self.nsamples) + sse_weight = self.alpha / self.ntnf _, labels_in_indices = labels_in.max(dim=1) ce_labels = _nn.CrossEntropyLoss()(labels_out, labels_in_indices) ce_labels_weight = 1.0 # TODO: figure out - sse = (tnf_out - tnf_in).pow(2).sum(dim=1) - sse_weight = self.VAEVamb.alpha / self.VAEVamb.ntnf - kld_weight = 1 / (self.VAEVamb.nlatent * self.VAEVamb.beta) - reconstruction_loss = ( - ce * ce_weight + sse * sse_weight + ce_labels * ce_labels_weight - ) + weighed_ab = ab_sse * ab_sse_weight + weighed_ce = ce * ce_weight + weighed_sse = sse * sse_weight + weighed_labels = ce_labels * ce_labels_weight + + reconstruction_loss = weighed_ce + weighed_ab + weighed_sse + weighed_labels kld_vamb = kld_gauss(mu_sup, logsigma_sup, mu_vamb_unsup, logsigma_vamb_unsup) kld_labels = kld_gauss( mu_sup, logsigma_sup, mu_labels_unsup, logsigma_labels_unsup ) kld = kld_vamb + kld_labels + kld_weight = 1 / (self.nlatent * self.beta) kld_loss = kld * kld_weight loss = (reconstruction_loss + kld_loss) * weights @@ -856,37 +891,43 @@ def trainepoch(self, data_loader, epoch, optimizer, batchsteps, logfile): for ( depths_in_sup, tnf_in_sup, + abundance_in_sup, weights_in_sup, labels_in_sup, depths_in_unsup, tnf_in_unsup, + abundance_in_unsup, weights_in_unsup, labels_in_unsup, ) in data_loader: depths_in_sup.requires_grad = True tnf_in_sup.requires_grad = True + abundance_in_sup.requires_grad = True labels_in_sup.requires_grad = True depths_in_unsup.requires_grad = True tnf_in_unsup.requires_grad = True + 
abundance_in_unsup.requires_grad = True labels_in_unsup.requires_grad = True if self.VAEVamb.usecuda: depths_in_sup = depths_in_sup.cuda() tnf_in_sup = tnf_in_sup.cuda() + abundance_in_sup = abundance_in_sup.cuda() weights_in_sup = weights_in_sup.cuda() labels_in_sup = labels_in_sup.cuda() depths_in_unsup = depths_in_unsup.cuda() tnf_in_unsup = tnf_in_unsup.cuda() + abundance_in_unsup = abundance_in_unsup.cuda() weights_in_unsup = weights_in_unsup.cuda() labels_in_unsup = labels_in_unsup.cuda() optimizer.zero_grad() - _, _, _, mu_sup, logsigma_sup = self.VAEJoint( - depths_in_sup, tnf_in_sup, labels_in_sup + _, _, _, _, mu_sup, logsigma_sup = self.VAEJoint( + depths_in_sup, tnf_in_sup, abundance_in_sup, labels_in_sup ) # use the two-modality latent space - depths_out_sup, tnf_out_sup = self.VAEVamb._decode( + depths_out_sup, tnf_out_sup, abundance_out_sup = self.VAEVamb._decode( self.VAEVamb.reparameterize(mu_sup) ) # use the one-modality decoders labels_out_sup = self.VAELabels._decode( @@ -896,13 +937,14 @@ def trainepoch(self, data_loader, epoch, optimizer, batchsteps, logfile): ( depths_out_unsup, tnf_out_unsup, + abundance_out_unsup, mu_vamb_unsup, - ) = self.VAEVamb(depths_in_unsup, tnf_in_unsup) + ) = self.VAEVamb(depths_in_unsup, tnf_in_unsup, abundance_in_unsup) logsigma_vamb_unsup = _torch.zeros(mu_vamb_unsup.size()) if self.usecuda: logsigma_vamb_unsup = logsigma_vamb_unsup.cuda() - _, _, mu_vamb_sup_s = self.VAEVamb( - depths_in_sup, tnf_in_sup + _, _, _, mu_vamb_sup_s = self.VAEVamb( + depths_in_sup, tnf_in_sup, abundance_in_sup ) logsigma_vamb_sup_s = _torch.zeros(mu_vamb_sup_s.size()) if self.usecuda: @@ -922,6 +964,8 @@ def trainepoch(self, data_loader, epoch, optimizer, batchsteps, logfile): depths_out_unsup, tnf_in_unsup, tnf_out_unsup, + abundance_in_unsup, + abundance_out_unsup, mu_vamb_unsup, weights_in_unsup, ) @@ -943,6 +987,8 @@ def trainepoch(self, data_loader, epoch, optimizer, batchsteps, logfile): depths_out_sup, tnf_in_sup, tnf_out_sup, + abundance_in_sup, + abundance_out_sup, labels_in_sup, labels_out_sup, mu_sup, diff --git a/workflow_vaevae/src/shortread_CAMI2_no_predictor.sh b/workflow_vaevae/src/shortread_CAMI2_no_predictor.sh index 8e80270d..e0027e2d 100755 --- a/workflow_vaevae/src/shortread_CAMI2_no_predictor.sh +++ b/workflow_vaevae/src/shortread_CAMI2_no_predictor.sh @@ -13,7 +13,7 @@ keyword=$3 vamb \ --model vaevae \ - --outdir /home/projects/cpr_10006/people/svekut/cami2_${dataset}_out_32_no_predictor_${run_id}_${keyword}_1000 \ + --outdir /home/projects/cpr_10006/people/svekut/cami2_${dataset}_out_32_no_predictor_${run_id}_${keyword}_abs \ --fasta /home/projects/cpr_10006/projects/vamb/data/datasets/cami2_${dataset}/contigs_2kbp.fna.gz \ --rpkm /home/projects/cpr_10006/projects/vamb/data/datasets/cami2_${dataset}/abundance.npz \ --taxonomy /home/projects/cpr_10006/people/svekut/04_mmseq2/taxonomy_cami_kfold/${dataset}_taxonomy_${run_id}.tsv \ @@ -27,13 +27,13 @@ vamb \ --cuda \ --minfasta 200000 -vamb \ - --model reclustering \ - --latent_path /home/projects/cpr_10006/people/svekut/cami2_${dataset}_out_32_no_predictor_${run_id}_${keyword}_1000/vaevae_latent.npy \ - --clusters_path /home/projects/cpr_10006/people/svekut/cami2_${dataset}_out_32_no_predictor_${run_id}_${keyword}_1000/vaevae_clusters.tsv \ - --fasta /home/projects/cpr_10006/projects/vamb/data/datasets/cami2_${dataset}/contigs_2kbp.fna.gz \ - --rpkm /home/projects/cpr_10006/projects/vamb/data/datasets/cami2_${dataset}/abundance.npz \ - --outdir 
/home/projects/cpr_10006/people/svekut/cami2_${dataset}_out_32_no_predictor_reclustering_${run_id}_${keyword}_1000 \ - --hmmout_path /home/projects/cpr_10006/projects/semi_vamb/data/marker_genes/markers_cami_${dataset}.hmmout \ - --algorithm kmeans \ - --minfasta 200000 +# vamb \ +# --model reclustering \ +# --latent_path /home/projects/cpr_10006/people/svekut/cami2_${dataset}_out_32_no_predictor_${run_id}_${keyword}_1000/vaevae_latent.npy \ +# --clusters_path /home/projects/cpr_10006/people/svekut/cami2_${dataset}_out_32_no_predictor_${run_id}_${keyword}_1000/vaevae_clusters.tsv \ +# --fasta /home/projects/cpr_10006/projects/vamb/data/datasets/cami2_${dataset}/contigs_2kbp.fna.gz \ +# --rpkm /home/projects/cpr_10006/projects/vamb/data/datasets/cami2_${dataset}/abundance.npz \ +# --outdir /home/projects/cpr_10006/people/svekut/cami2_${dataset}_out_32_no_predictor_reclustering_${run_id}_${keyword}_1000 \ +# --hmmout_path /home/projects/cpr_10006/projects/semi_vamb/data/marker_genes/markers_cami_${dataset}.hmmout \ +# --algorithm kmeans \ +# --minfasta 200000 diff --git a/workflow_vaevae/src/shortread_CAMI2_predictor.sh b/workflow_vaevae/src/shortread_CAMI2_predictor.sh index a228abd5..f6505e07 100755 --- a/workflow_vaevae/src/shortread_CAMI2_predictor.sh +++ b/workflow_vaevae/src/shortread_CAMI2_predictor.sh @@ -8,7 +8,7 @@ keyword=$3 # vamb \ --model taxonomy_predictor \ - --outdir /home/projects/cpr_10006/people/svekut/cami2_${dataset}_predictor_${keyword}_${run_id}_layers \ + --outdir /home/projects/cpr_10006/people/svekut/cami2_${dataset}_predictor_${keyword}_${run_id}_abs_in \ --fasta /home/projects/cpr_10006/projects/vamb/data/datasets/cami2_${dataset}/contigs_2kbp.fna.gz \ --rpkm /home/projects/cpr_10006/projects/vamb/data/datasets/cami2_${dataset}/abundance.npz \ --taxonomy /home/projects/cpr_10006/people/svekut/04_mmseq2/taxonomy_cami_kfold/${dataset}_taxonomy_${run_id}.tsv \ diff --git a/workflow_vaevae/src/shortread_marine_predictor.sh b/workflow_vaevae/src/shortread_marine_predictor.sh index 07c87cfb..2d479f37 100755 --- a/workflow_vaevae/src/shortread_marine_predictor.sh +++ b/workflow_vaevae/src/shortread_marine_predictor.sh @@ -6,7 +6,7 @@ vamb \ --outdir /home/projects/cpr_10006/people/svekut/cami2_marine_predictor_taxometer_${keyword} \ --fasta /home/projects/cpr_10006/data/CAMI2/marine/vamb_output/contigs_2kbp.fna \ --rpkm /home/projects/cpr_10006/data/CAMI2/marine/vamb_output/vambout/abundance.npz \ - --taxonomy /home/projects/cpr_10006/people/svekut/04_mmseq2/taxonomy_cami_kfold/marine_taxonomy_taxometer_${keyword}.tsv \ + --taxonomy /home/projects/cpr_10006/people/svekut/04_mmseq2/taxonomy_cami_kfold/marine_taxonomy_${keyword}.tsv \ -pe 100 \ -pq \ -pt 1024 \ diff --git a/workflow_vaevae/src/shortread_rhizo_predictor.sh b/workflow_vaevae/src/shortread_rhizo_predictor.sh index 7b66ed40..ee0209fd 100755 --- a/workflow_vaevae/src/shortread_rhizo_predictor.sh +++ b/workflow_vaevae/src/shortread_rhizo_predictor.sh @@ -6,7 +6,7 @@ vamb \ --outdir /home/projects/cpr_10006/people/svekut/cami2_rhizo_predictor_taxometer_${keyword} \ --fasta /home/projects/cpr_10006/data/CAMI2/rhizo/vamb_output/contigs_2kbp.fna \ --rpkm /home/projects/cpr_10006/data/CAMI2/rhizo/vamb_output/vambout/abundance.npz \ - --taxonomy /home/projects/cpr_10006/people/svekut/04_mmseq2/taxonomy_cami_kfold/rhizo_taxonomy_taxometer_${keyword}.tsv \ + --taxonomy /home/projects/cpr_10006/people/svekut/04_mmseq2/taxonomy_cami_kfold/rhizo_taxonomy_${keyword}.tsv \ -pe 100 \ -pq \ -pt 1024 \
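Beyond the loss changes, the patch threads the new total-abundance tensor through every dataset, collate function and forward signature: _make_dataset returns one extra tensor, the supervised TensorDataset holds five tensors instead of four, and the VAEs and the VAMB2Label predictor take an additional abundances argument. A minimal sketch of the new tensor ordering (shapes and values are invented for illustration; only the ordering mirrors the patch):

import torch
from torch.utils.data import DataLoader, TensorDataset

n_contigs, nsamples, ntnf = 6, 4, 103

depths = torch.rand(n_contigs, nsamples)       # per-sample depths (RPKM)
tnf = torch.randn(n_contigs, ntnf)             # tetranucleotide frequencies
total_abundance = torch.rand(n_contigs, 1)     # new total-abundance column
weights = torch.ones(n_contigs)                # per-contig weights
labels = torch.randint(0, 10, (n_contigs,))    # integer taxonomy labels

# After this patch every consumer must unpack five tensors in this order.
dataset = TensorDataset(depths, tnf, total_abundance, weights, labels)
loader = DataLoader(dataset, batch_size=3)
for depths_in, tnf_in, abundance_in, weights_in, labels_in in loader:
    print(depths_in.shape, abundance_in.shape, labels_in.shape)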