From d6977320ffab608bbf9ad929084bf6efd1030ce1 Mon Sep 17 00:00:00 2001 From: Jakob Nybo Nissen Date: Tue, 12 Sep 2023 16:18:34 +0200 Subject: [PATCH] Simplify get_latents function This function is only ever called with last_epoch=True, so just remove this argument. Also make it more type stable. --- vamb/aamb_encode.py | 31 ++++++++++++------------------- 1 file changed, 12 insertions(+), 19 deletions(-) diff --git a/vamb/aamb_encode.py b/vamb/aamb_encode.py index 162150f4..db95561b 100644 --- a/vamb/aamb_encode.py +++ b/vamb/aamb_encode.py @@ -13,6 +13,7 @@ from collections.abc import Sequence from typing import Optional, IO, Union +from numpy.typing import NDArray ############################################################################# MODEL ########################################################### @@ -425,8 +426,8 @@ def trainmodel( ########### funciton that retrieves the clusters from Y latents def get_latents( - self, contignames: Sequence[str], data_loader, last_epoch: bool = True - ): + self, contignames: Sequence[str], data_loader + ) -> tuple[dict[str, set[str]], NDArray[np.float32]]: """Retrieve the categorical latent representation (y) and the contiouous latents (l) of the inputs Inputs: @@ -447,7 +448,7 @@ def get_latents( latent = np.empty((length, self.ld), dtype=np.float32) index_contigname = 0 row = 0 - clust_y_dict = dict() + clust_y_dict: dict[str, set[str]] = dict() Tensor = torch.cuda.FloatTensor if self.usecuda else torch.FloatTensor with torch.no_grad(): for depths_in, tnfs_in, _ in new_data_loader: @@ -473,23 +474,18 @@ def get_latents( depths_in = depths_in.cuda() tnfs_in = tnfs_in.cuda() - if last_epoch: - mu, _, _, _, y_sample = self(depths_in, tnfs_in)[0:5] - else: - y_sample = self(depths_in, tnfs_in)[4] + mu, _, _, _, y_sample = self(depths_in, tnfs_in)[0:5] if self.usecuda: Ys = y_sample.cpu().detach().numpy() - if last_epoch: - mu = mu.cpu().detach().numpy() - latent[row : row + len(mu)] = mu - row += len(mu) + mu = mu.cpu().detach().numpy() + latent[row : row + len(mu)] = mu + row += len(mu) else: Ys = y_sample.detach().numpy() - if last_epoch: - mu = mu.detach().numpy() - latent[row : row + len(mu)] = mu - row += len(mu) + mu = mu.detach().numpy() + latent[row : row + len(mu)] = mu + row += len(mu) del y_sample for _y in Ys: @@ -505,7 +501,4 @@ def get_latents( index_contigname += 1 del Ys - if last_epoch: - return clust_y_dict, latent - else: - return clust_y_dict + return clust_y_dict, latent