diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index a1897b8..e93fd58 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,6 +1,6 @@ repos: - repo: https://github.com/python/black - rev: 20.8b1 + rev: 22.3.0 hooks: - id: black - repo: https://gitlab.com/pycqa/flake8 diff --git a/hmivae/ScModeDataloader.py b/hmivae/ScModeDataloader.py new file mode 100644 index 0000000..ddfa9c1 --- /dev/null +++ b/hmivae/ScModeDataloader.py @@ -0,0 +1,152 @@ +import numpy as np +import pandas as pd +import torch +from sklearn.preprocessing import OneHotEncoder, StandardScaler +from torch.utils.data import TensorDataset + + +def sparse_numpy_to_torch(adj_mat): + """Construct sparse torch tensor + Need to do csr -> coo + then follow https://stackoverflow.com/questions/50665141/converting-a-scipy-coo-matrix-to-pytorch-sparse-tensor + """ + adj_mat_coo = adj_mat.tocoo() + + values = adj_mat_coo.data + indices = np.vstack((adj_mat_coo.row, adj_mat_coo.col)) + + i = torch.LongTensor(indices) + v = torch.FloatTensor(values) + shape = adj_mat_coo.shape + + return torch.sparse_coo_tensor(i, v, shape) + + +def get_n_cell_neighbours(adj_mat): + """Get the sum of a sparse matrix + Need to first replace all non-zero elements with 1 + Then add them up to get the number of neighbours + """ + adj_mat[adj_mat.nonzero()] = 1.0 + n_neighbours_sparse = adj_mat.sum(1) + + return np.asarray(n_neighbours_sparse) + + +class ScModeDataloader(TensorDataset): + def __init__(self, adata, scalers=None): + """ + Need to get the following from adata: + Y - NxP mean expression matrix + S - Nx(pC2) correlation matrix + M - Nx7 morphology matrix + scalers: set of data scalers + """ + self.adata = adata + Y = adata.X # per cell protein mean expression + S = adata.obsm["correlations"] + M = adata.obsm["morphology"] + weights = adata.obsm["weights"] + + self.n_cells = Y.shape[0] # number of cells + + if scalers is None: + self.scalers = {} + self.scalers["Y"] = 
StandardScaler().fit(Y) + self.scalers["S"] = StandardScaler().fit(S) + self.scalers["M"] = StandardScaler().fit(M) + + else: + self.scalers = scalers + + Y = self.scalers["Y"].transform(Y) + S = self.scalers["S"].transform(S) + M = self.scalers["M"].transform(M) + + self.Y = torch.tensor(Y).float() + self.S = torch.tensor(S).float() + self.M = torch.tensor(M).float() + self.C = self.get_spatial_context() + self.weights = torch.tensor( + weights + ).float() # these don't need to be scaled, not a data input + + self.samples_onehot = self.one_hot_encoding() + + if "background_covs" in adata.obsm.keys(): # dealing with background covariates + BKG = adata.obsm["background_covs"] + if scalers is None: + self.scalers["BKG"] = StandardScaler().fit(BKG) + BKG = self.scalers["BKG"].transform(BKG) + else: + BKG = self.scalers["BKG"].transform(BKG) + + self.BKG = torch.tensor(BKG).float() + else: + self.BKG = None + + def __len__(self): + return self.Y.shape[0] + + def one_hot_encoding(self, test=False): + """ + Creates a onehot encoding for samples. + """ + onehotenc = OneHotEncoder() + X = self.adata.obs[["Sample_name"]] + onehot_X = onehotenc.fit_transform(X).toarray() + + df = pd.DataFrame(onehot_X, columns=onehotenc.categories_[0]) + + df = df.reindex(columns=self.adata.obs.Sample_name.unique().tolist()) + + return torch.tensor(df.to_numpy()).float() + + def get_spatial_context(self): + """ + Multiplies the sparse neighbourhood matrix to protein mean expression (self.Y), + protein-protein correlation (self.S) and cell morphology (self.M) matrices. + The product-sum is normalized by the number of neighbours each cell has. + The resulting matrix, self.C, is the spatial context. 
+ """ + adj_mat = sparse_numpy_to_torch( + self.adata.obsp["connectivities"] + ) # adjacency matrix + concatenated_features = torch.cat((self.Y, self.S, self.M), 1) + + n_cell_neighbours = get_n_cell_neighbours( + self.adata.copy().obsp["connectivities"] + ) + + unnormalized_C = torch.smm( + adj_mat, concatenated_features + ).to_dense() # unnormalized spatial context for each cell + + C = torch.div( + unnormalized_C, torch.tensor(n_cell_neighbours) + ) # normalize by number of adjacent cells + return C + + def __getitem__(self, idx): + + if self.BKG is None: + return ( + self.Y[idx, :], + self.S[idx, :], + self.M[idx, :], + self.C[idx, :], + self.samples_onehot[idx, :], + self.weights[idx, :], + idx, + ) + else: + return ( + self.Y[idx, :], + self.S[idx, :], + self.M[idx, :], + self.C[idx, :], + self.samples_onehot[idx, :], + self.weights[idx, :], + self.BKG[idx, :], + idx, + ) diff --git a/mypackage/__init__.py b/hmivae/__init__.py similarity index 100% rename from mypackage/__init__.py rename to hmivae/__init__.py diff --git a/hmivae/_hmivae_base_components.py b/hmivae/_hmivae_base_components.py new file mode 100644 index 0000000..757cfd4 --- /dev/null +++ b/hmivae/_hmivae_base_components.py @@ -0,0 +1,368 @@ +from typing import Literal, Optional, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class EncoderHMIVAE(nn.Module): + """Encoder for the case in which data is merged after initial encoding + input_exp_dim: Dimension for the original mean expression input + input_corr_dim: Dimension for the original correlations input + input_morph_dim: Dimension for the original morphology input + input_spcont_dim: Dimension for the original spatial context input + E_me: Dimension for the encoded mean expressions input + E_cr: Dimension for the encoded correlations input + E_mr: Dimension for the encoded morphology input + E_sc: Dimension for the encoded spatial context input + latent_dim: Dimension of the encoded output + E_cov: 
Dimension for the encoded covariates input + n_covariates: Number of covariates + n_hidden: Number of hidden layers, default=1 + leave_out_view: For ablation testing. View to leave out, default=None + """ + + def __init__( + self, + input_exp_dim: int, + input_corr_dim: int, + input_morph_dim: int, + input_spcont_dim: int, + E_me: int, + E_cr: int, + E_mr: int, + E_sc: int, + latent_dim: int, + E_cov: Optional[int] = 10, + n_covariates: Optional[int] = 0, + n_hidden: Optional[int] = 1, + leave_out_view: Optional[ + Union[None, Literal["expression", "correlation", "morphology", "spatial"]] + ] = None, + ): + super().__init__() + hidden_dim = E_me + E_cr + E_mr + E_sc + E_cov + + self.leave_out_view = leave_out_view + + self.input_cov = nn.Linear(n_covariates, E_cov) + + self.input_exp = nn.Linear(input_exp_dim, E_me) + self.exp_hidden = nn.Linear(E_me, E_me) + + self.input_corr = nn.Linear(input_corr_dim, E_cr) + self.corr_hidden = nn.Linear(E_cr, E_cr) + self.input_morph = nn.Linear(input_morph_dim, E_mr) + self.morph_hidden = nn.Linear(E_mr, E_mr) + self.input_spatial_context = nn.Linear(input_spcont_dim, E_sc) + self.spatial_context_hidden = nn.Linear(E_sc, E_sc) + + self.linear = nn.ModuleList( + [nn.Linear(hidden_dim, hidden_dim) for i in range(n_hidden)] + ) + + self.mu_z = nn.Linear(hidden_dim, latent_dim) + self.std_z = nn.Linear(hidden_dim, latent_dim) + + def forward( + self, + x_mean: torch.Tensor, + x_correlations: torch.Tensor, + x_morphology: torch.Tensor, + x_spatial_context: torch.Tensor, + cov_list=torch.Tensor([]), + ): + + # included_views = [] + + # if self.leave_out_view is None: + + h_mean = F.elu(self.input_exp(x_mean)) + h_mean2 = F.elu(self.exp_hidden(h_mean)) + + h_correlations = F.elu(self.input_corr(x_correlations)) + h_correlations2 = F.elu(self.corr_hidden(h_correlations)) + + h_morphology = F.elu(self.input_morph(x_morphology)) + h_morphology2 = F.elu(self.morph_hidden(h_morphology)) + + h_spatial_context = 
F.elu(self.input_spatial_context(x_spatial_context)) + h_spatial_context2 = F.elu(self.spatial_context_hidden(h_spatial_context)) + + if cov_list.shape[0] > 1: + h_cov = F.elu(self.input_cov(cov_list)) + else: + h_cov = cov_list + + h = torch.cat( + [h_mean2, h_correlations2, h_morphology2, h_spatial_context2, h_cov], 1 + ).type_as(x_mean) + # else: + # if self.leave_out_view != "expression": + # h_mean = F.elu(self.input_exp(x_mean)) + # h_mean2 = F.elu(self.exp_hidden(h_mean)) + + # included_views.append(h_mean2) + + # if self.leave_out_view != "correlation": + # h_correlations = F.elu(self.input_corr(x_correlations)) + # h_correlations2 = F.elu(self.corr_hidden(h_correlations)) + + # included_views.append(h_correlations2) + + # if self.leave_out_view != "morphology": + # h_morphology = F.elu(self.input_morph(x_morphology)) + # h_morphology2 = F.elu(self.morph_hidden(h_morphology)) + + # included_views.append(h_morphology2) + + # if self.leave_out_view != "spatial": + # h_spatial_context = F.elu(self.input_spatial_context(x_spatial_context)) + # h_spatial_context2 = F.elu( + # self.spatial_context_hidden(h_spatial_context) + # ) + + # included_views.append(h_spatial_context2) + + # if cov_list.shape[0] > 1: + # h_cov = F.elu(self.input_cov(cov_list)) + # included_views.append(h_cov) + # else: + # h_cov = cov_list + # included_views.append(h_cov) + + # h = torch.cat(included_views, 1) # .type_as(x_mean) + + for net in self.linear: + h = F.elu(net(h)) + + mu_z = self.mu_z(h) + + log_std_z = self.std_z(h) + + return mu_z, log_std_z + + +class DecoderHMIVAE(nn.Module): + """ + Decoder for the case where data is merged after inital encoding + latent_dim: Dimension of the encoded input + E_me: Dimension for the encoded mean expressions input + E_cr: Dimension for the encoded correlations input + E_mr: Dimension for the encoded morphology input + E_sc: Dimension for the encoded spatial context input + input_exp_dim: Dimension for the decoded mean expression output + 
input_corr_dim: Dimension for the decoded correlations output + input_morph_dim: Dimension for the decoded morphology input + input_spcont_dim: Dimension for the decoded spatial context input + E_cov: Dimension for the encoded covariates input + n_covariates: Number of covariates + linear_decoder: True or False for using a linear decoder + n_hidden: Number of hidden layers, default=1 + leave_out_view: For ablation testing. View to leave out during training, default=None + """ + + def __init__( + self, + latent_dim: int, + E_me: int, + E_cr: int, + E_mr: int, + E_sc: int, + input_exp_dim: int, + input_corr_dim: int, + input_morph_dim: int, + input_spcont_dim: int, + E_cov: Optional[int] = 10, + n_covariates: Optional[int] = 0, + linear_decoder: Optional[bool] = False, + n_hidden: Optional[int] = 1, + leave_out_view: Optional[ + Union[None, Literal["expression", "correlation", "morphology", "spatial"]] + ] = None, + ): + super().__init__() + hidden_dim = E_me + E_cr + E_mr + E_sc + E_cov + latent_dim = latent_dim + n_covariates + self.leave_out_view = leave_out_view + self.E_me = E_me + self.E_cr = E_cr + self.E_mr = E_mr + self.E_sc = E_sc + self.input = nn.Linear(latent_dim, hidden_dim) + self.linear_decoder = linear_decoder + self.linear = nn.ModuleList( + [nn.Linear(hidden_dim, hidden_dim) for i in range(n_hidden)] + ) + # mean expression + self.exp_hidden = nn.Linear(E_me, E_me) + self.mu_x_exp = nn.Linear(E_me, input_exp_dim) + self.std_x_exp = nn.Linear(E_me, input_exp_dim) + + # correlations/co-localizations + self.corr_hidden = nn.Linear(E_cr, E_cr) + self.mu_x_corr = nn.Linear(E_cr, input_corr_dim) + self.std_x_corr = nn.Linear(E_cr, input_corr_dim) + + # morphology + self.morph_hidden = nn.Linear(E_mr, E_mr) + self.mu_x_morph = nn.Linear(E_mr, input_morph_dim) + self.std_x_morph = nn.Linear(E_mr, input_morph_dim) + + # spatial context + self.spatial_context_hidden = nn.Linear(E_sc, E_sc) + self.mu_x_spcont = nn.Linear(E_sc, input_spcont_dim) + 
self.std_x_spcont = nn.Linear(E_sc, input_spcont_dim) + + def forward(self, z, cov_list): + z_s = torch.cat( + [z, cov_list], 1 + ) # takes in one-hot as input, doesn't need to be symmetric with the encoder, doesn't output it + + if ( + self.linear_decoder + ): # linear decoder, no activation functions and single linear layer + out = self.input(z_s) + + h2_mean = self.exp_hidden(out[:, 0 : self.E_me]) + h2_correlations = self.corr_hidden( + out[:, self.E_me : self.E_me + self.E_cr] + ) + h2_morphology = self.morph_hidden( + out[:, self.E_me + self.E_cr : self.E_me + self.E_cr + self.E_mr] + ) + h2_spatial_context = self.spatial_context_hidden( + out[ + :, + self.E_me + + self.E_cr + + self.E_mr : self.E_me + + self.E_cr + + self.E_mr + + self.E_sc, + ] + ) + + # covariates = out[:, self.E_me + self.E_cr + self.E_mr + self.E_sc :] + + else: + # standard decoder with activation functions (non-linear) + out = F.elu(self.input(z_s)) + for net in self.linear: + out = F.elu(net(out)) + + # if self.leave_out_view is None: + + h2_mean = F.elu(self.exp_hidden(out[:, 0 : self.E_me])) + h2_correlations = F.elu( + self.corr_hidden(out[:, self.E_me : self.E_me + self.E_cr]) + ) + h2_morphology = F.elu( + self.morph_hidden( + out[:, self.E_me + self.E_cr : self.E_me + self.E_cr + self.E_mr] + ) + ) + h2_spatial_context = F.elu( + self.spatial_context_hidden( + out[ + :, + self.E_me + + self.E_cr + + self.E_mr : self.E_me + + self.E_cr + + self.E_mr + + self.E_sc, + ] + ) + ) + + # covariates = out[:, self.E_me + self.E_cr + self.E_mr + self.E_sc :] + + mu_x_exp = self.mu_x_exp(h2_mean) + std_x_exp = self.std_x_exp(h2_mean) + + # if self.use_weights: + # with torch.no_grad(): + # weights = self.get_corr_weights_per_cell( + # mu_x_exp.detach() + # ) # calculating correlation weights + # else: + # weights = None + + mu_x_corr = self.mu_x_corr(h2_correlations) + std_x_corr = self.std_x_corr(h2_correlations) + + mu_x_morph = self.mu_x_morph(h2_morphology) + std_x_morph = 
self.std_x_morph(h2_morphology) + + mu_x_spatial_context = self.mu_x_spcont(h2_spatial_context) + std_x_spatial_context = self.std_x_spcont(h2_spatial_context) + + return ( + mu_x_exp, + std_x_exp, + mu_x_corr, + std_x_corr, + mu_x_morph, + std_x_morph, + mu_x_spatial_context, + std_x_spatial_context, + # weights, + ) + + # else: + # included_views = [] + + # if self.leave_out_view != "expression": + # h2_mean = F.elu(self.exp_hidden(out[:, 0 : self.E_me])) + # mu_x_exp = self.mu_x_exp(h2_mean) + # std_x_exp = self.std_x_exp(h2_mean) + + # included_views.append(mu_x_exp) + # included_views.append(std_x_exp) + + # if self.leave_out_view != "correlation": + # h2_correlations = F.elu( + # self.corr_hidden(out[:, self.E_me : self.E_me + self.E_cr]) + # ) + # mu_x_corr = self.mu_x_corr(h2_correlations) + # std_x_corr = self.std_x_corr(h2_correlations) + + # included_views.append(mu_x_corr) + # included_views.append(std_x_corr) + + # if self.leave_out_view != "morphology": + # h2_morphology = F.elu( + # self.morph_hidden( + # out[ + # :, self.E_me + self.E_cr : self.E_me + self.E_cr + self.E_mr + # ] + # ) + # ) + # mu_x_morph = self.mu_x_morph(h2_morphology) + # std_x_morph = self.std_x_morph(h2_morphology) + + # included_views.append(mu_x_morph) + # included_views.append(std_x_morph) + + # if self.leave_out_view != "spatial": + # h2_spatial_context = F.elu( + # self.spatial_context_hidden( + # out[ + # :, + # self.E_me + # + self.E_cr + # + self.E_mr : self.E_me + # + self.E_cr + # + self.E_mr + # + self.E_sc, + # ] + # ) + # ) + # mu_x_spatial_context = self.mu_x_spcont(h2_spatial_context) + # std_x_spatial_context = self.std_x_spcont(h2_spatial_context) + + # included_views.append(mu_x_spatial_context) + # included_views.append(std_x_spatial_context) + + # return included_views diff --git a/hmivae/_hmivae_model.py b/hmivae/_hmivae_model.py new file mode 100644 index 0000000..273847c --- /dev/null +++ b/hmivae/_hmivae_model.py @@ -0,0 +1,493 @@ +import inspect +import 
logging +from typing import List, Literal, Optional, Union + +import anndata as ad +import numpy as np +import pytorch_lightning as pl +import torch +from anndata import AnnData +from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint +from pytorch_lightning.callbacks.progress import RichProgressBar +from pytorch_lightning.loggers import WandbLogger +from pytorch_lightning.trainer import Trainer +from scipy.stats.mstats import winsorize +from sklearn.model_selection import train_test_split +from torch.utils.data import DataLoader + +# import hmivae +import hmivae._hmivae_module as module +import hmivae.ScModeDataloader as ScModeDataloader + +logger = logging.getLogger(__name__) + + +class hmivaeModel(pl.LightningModule): + """ + Skeleton for an scvi-tools model. + + Parameters + ---------- + adata + AnnData object that has been registered via :meth:`~mypackage.MyModel.setup_anndata`. + n_hidden + Number of nodes per hidden layer. + n_latent + Dimensionality of the latent space. + n_layers + Number of hidden layers used for encoder and decoder NNs. 
+ **model_kwargs + Keyword args for :class:`~mypackage.MyModule` + Examples + -------- + >>> adata = anndata.read_h5ad(path_to_anndata) + >>> mypackage.MyModel.setup_anndata(adata, batch_key="batch") + >>> vae = mypackage.MyModel(adata) + >>> vae.train() + >>> adata.obsm["X_mymodel"] = vae.get_latent_representation() + """ + + def __init__( + self, + adata: AnnData, + input_exp_dim: int, + input_corr_dim: int, + input_morph_dim: int, + input_spcont_dim: int, + E_me: int = 32, + E_cr: int = 32, + E_mr: int = 32, + E_sc: int = 32, + E_cov: int = 10, + latent_dim: int = 10, + use_covs: bool = False, + use_weights: bool = True, + n_covariates: Optional[Union[None, int]] = None, + cohort: Optional[Union[None, str]] = None, + n_hidden: int = 1, + cofactor: float = 1.0, + beta_scheme: Optional[Literal["constant", "warmup"]] = "warmup", + linear_decoder: Optional[bool] = False, + batch_correct: bool = True, + is_trained_model: bool = False, + batch_size: Optional[int] = 1234, + random_seed: Optional[int] = 1234, + leave_out_view: Optional[ + Union[None, Literal["expression", "correlation", "morphology", "spatial"]] + ] = None, + output_dir: str = ".", + **model_kwargs, + ): + # super(hmivaeModel, self).__init__(adata) + super().__init__() + + self.output_dir = output_dir + # self.adata = adata + self.use_covs = use_covs + self.use_weights = use_weights + self.leave_out_view = leave_out_view + self.is_trained_model = is_trained_model + self.random_seed = random_seed + self.name = f"{cohort}_rs{random_seed}_nh{n_hidden}_bs{batch_size}_hd{E_me}_ls{latent_dim}" + + if self.use_covs: + self.keys = [] + for key in adata.obsm.keys(): + # print(key) + if key not in ["correlations", "morphology", "spatial", "xy"]: + self.keys.append(key) + + if n_covariates is None: + raise ValueError("`n_covariates` cannot be None when `use_covs`==True") + else: + n_covariates = n_covariates + + # print("n_keys", len(self.keys)) + else: + self.keys = None + if n_covariates is None: + n_covariates 
= 0 + else: + n_covariates = 0 + print("`n_covariates` automatically set to 0 when use_covs == False") + + ( + self.train_batch, + self.test_batch, + n_samples, + self.features_config, + # self.cov_list, + ) = self.setup_anndata( + adata=adata, + protein_correlations_obsm_key="correlations", + cell_morphology_obsm_key="morphology", + continuous_covariate_keys=self.keys, + cofactor=cofactor, + image_correct=batch_correct, + batch_size=batch_size, + random_seed=random_seed, + ) + + n_covariates += n_samples + + print("n_covs", n_covariates) + + # for batch in self.train_batch: + # print('Y', torch.mean(batch[0],1)) + # print('S', torch.mean(batch[1],1)) + # print('M', torch.mean(batch[2],1)) + # print('C', torch.mean(batch[3],1)) + # print('one-hot', batch[4]) + # print('covariates', torch.mean(batch[5],1)) + # break + + # for batch in self.test_batch: + # print('Y_test', torch.mean(batch[0],1)) + # print('S_test', torch.mean(batch[1],1)) + # print('M_test', torch.mean(batch[2],1)) + # print('C_test', torch.mean(batch[3],1)) + # print('one-hot_test', batch[4]) + # print('covariates_test', torch.mean(batch[5],1)) + # break + + # print("cov_list", self.cov_list.shape) + + # print("self.adata", self.adata.X) + + # self.summary_stats provides information about anndata dimensions and other tensor info + self.module = module.hmiVAE( + input_exp_dim=input_exp_dim, + input_corr_dim=input_corr_dim, + input_morph_dim=input_morph_dim, + input_spcont_dim=input_spcont_dim, + E_me=E_me, + E_cr=E_cr, + E_mr=E_mr, + E_sc=E_sc, + E_cov=E_cov, + latent_dim=latent_dim, + n_covariates=n_covariates, + n_hidden=n_hidden, + use_covs=self.use_covs, + use_weights=self.use_weights, + linear_decoder=linear_decoder, + beta_scheme=beta_scheme, + batch_correct=batch_correct, + leave_out_view=leave_out_view, + **model_kwargs, + ) + self._model_summary_string = ( + "hmiVAE model with the following parameters: \n n_latent:{}, " + "n_protein_expression:{}, n_correlation:{}, n_morphology:{}, 
n_spatial_context:{}, " + "use_covariates:{} " + ).format( + latent_dim, + input_exp_dim, + input_corr_dim, + input_morph_dim, + input_spcont_dim, + use_covs, + ) + # necessary line to get params that will be used for saving/loading + self.init_params_ = self._get_init_params(locals()) + + logger.info("The model has been initialized") + + def _get_init_params(self, locals): + """ + Taken from: https://github.com/scverse/scvi-tools/blob/main/scvi/model/base/_base_model.py + """ + + init = self.__init__ + sig = inspect.signature(init) + parameters = sig.parameters.values() + + init_params = [p.name for p in parameters] + all_params = {p: locals[p] for p in locals if p in init_params} + + non_var_params = [p.name for p in parameters if p.kind != p.VAR_KEYWORD] + non_var_params = {k: v for (k, v) in all_params.items() if k in non_var_params} + var_params = [p.name for p in parameters if p.kind == p.VAR_KEYWORD] + var_params = {k: v for (k, v) in all_params.items() if k in var_params} + + user_params = {"kwargs": var_params, "non_kwargs": non_var_params} + + return user_params + + def train( + self, + max_epochs=15, + check_val_every_n_epoch=1, + config=None, + ): # misnomer, both train and test/val are here (either rename or separate) + + # with wandb.init(config=config): + # config=wandb.config + + pl.seed_everything(self.random_seed) + + early_stopping = EarlyStopping(monitor="recon_lik_test", mode="max", patience=1) + + cb_chkpt = ModelCheckpoint( + dirpath=f"{self.output_dir}", + monitor="recon_lik_test", + mode="max", + save_top_k=1, + filename="{epoch}_{step}_{recon_lik_test:.3f}", + ) + + cb_progress = RichProgressBar() + # wandb.finish() + # wandb_logger = WandbLogger(log_model="all") + + if self.leave_out_view is None: + + wandb_logger = WandbLogger( + project="hmivae_hyperparameter_runs", + name=self.name, + config=self.features_config, + ) + else: + wandb_logger = WandbLogger( + project="hmivae_ablation", config=self.features_config + ) + + trainer = 
Trainer( + max_epochs=max_epochs, + check_val_every_n_epoch=check_val_every_n_epoch, + callbacks=[early_stopping, cb_progress, cb_chkpt], + logger=wandb_logger, + # overfit_batches=0.01, + gradient_clip_val=2.0, + accelerator="auto", + devices="auto", + log_every_n_steps=1, + # limit_train_batches=0.1, + # limit_val_batches=0.1, + ) + + trainer.fit(self.module, self.train_batch, self.test_batch) + + # wandb.finish() + + @torch.no_grad() + def get_latent_representation( + self, + adata: AnnData, + protein_correlations_obsm_key: str, + cell_morphology_obsm_key: str, + continuous_covariate_keys: Optional[List[str]] = None, + cofactor: float = 1.0, + is_trained_model: Optional[bool] = False, + batch_correct: Optional[bool] = True, + use_covs: Optional[bool] = True, + save_view_specific_embeddings: Optional[bool] = True, + ) -> AnnData: + """ + Gives the latent representation of each cell. + """ + if is_trained_model: + ( + adata_train, + adata_test, + data_train, + data_test, + n_covariates, + # cat_list, + # train_idx, + # test_idx, + ) = self.setup_anndata( + adata, + protein_correlations_obsm_key, + cell_morphology_obsm_key, + continuous_covariate_keys=continuous_covariate_keys, + cofactor=cofactor, + is_trained_model=is_trained_model, + image_correct=batch_correct, + ) + + if save_view_specific_embeddings: + ( + adata_train.obsm["VAE"], + adata_train.obsm["expression_embedding"], + adata_train.obsm["correlation_embedding"], + adata_train.obsm["morphology_embedding"], + adata_train.obsm["spatial_context_embedding"], + ) = self.module.inference( + data_train, + n_covariates=n_covariates, + use_covs=use_covs, + batch_correct=batch_correct, + ) # idx=train_idx) + + ( + adata_test.obsm["VAE"], + adata_test.obsm["expression_embedding"], + adata_test.obsm["correlation_embedding"], + adata_test.obsm["morphology_embedding"], + adata_test.obsm["spatial_context_embedding"], + ) = self.module.inference( + data_test, + n_covariates=n_covariates, + use_covs=use_covs, + 
batch_correct=batch_correct, + ) # idx=test_idx) + + else: + adata_train.obsm["VAE"] = self.module.inference( + data_train, + n_covariates=n_covariates, + use_covs=use_covs, + batch_correct=batch_correct, + ) # idx=train_idx) + adata_test.obsm["VAE"] = self.module.inference( + data_test, + n_covariates=n_covariates, + use_covs=use_covs, + batch_correct=batch_correct, + ) # idx=test_idx) + + return ad.concat([adata_train, adata_test], uns_merge="first") + else: + raise Exception( + "No latent representation to produce! Model is not trained!" + ) + + # @setup_anndata_dsp.dedent + @staticmethod + def setup_anndata( + # self, + adata: AnnData, + protein_correlations_obsm_key: str, + cell_morphology_obsm_key: str, + # cell_spatial_context_obsm_key: str, + protein_correlations_names_uns_key: Optional[str] = None, + cell_morphology_names_uns_key: Optional[str] = None, + image_correct: bool = True, + batch_size: Optional[int] = 4321, + batch_key: Optional[str] = None, + labels_key: Optional[str] = None, + layer: Optional[str] = None, + categorical_covariate_keys: Optional[List[str]] = None, + continuous_covariate_keys: Optional[ + List[str] + ] = None, # obsm keys for other categories + cofactor: float = 1.0, + train_prop: Optional[float] = 0.75, + apply_winsorize: Optional[bool] = True, + arctanh_corrs: Optional[bool] = False, + is_trained_model: Optional[bool] = False, + random_seed: Optional[int] = 1234, + copy: bool = False, + ) -> Optional[AnnData]: + """ + %(summary)s. + Takes in an AnnData object and returns the train and test loaders. 
+ Parameters + ---------- + %(param_adata)s + %(param_batch_key)s + %(param_labels_key)s + %(param_layer)s + %(param_cat_cov_keys)s + %(param_cont_cov_keys)s + %(param_copy)s + + Returns + ------- + %(returns)s + """ + N_PROTEINS = adata.shape[1] + N_MORPHOLOGY = len(adata.uns["names_morphology"]) + + # print("adata in setup_adata", adata.X) + + if continuous_covariate_keys is not None: + cat_list = [] + for cat_key in continuous_covariate_keys: + # print(cat_key) + # print(f"{cat_key} shape:", adata.obsm[cat_key].shape) + category = adata.obsm[cat_key] + cat_list.append(category) + cat_list = np.concatenate(cat_list, 1) + n_cats = cat_list.shape[1] + # if apply_winsorize: + # for i in range(cat_list.shape[1]): + # cat_list[:, i] = winsorize(cat_list[:, i], limits=[0, 0.01]) + + adata.obsm["background_covs"] = cat_list + else: + n_cats = 0 + + adata.X = np.arcsinh(adata.X / cofactor) + + if apply_winsorize: + for i in range(N_PROTEINS): + adata.X[:, i] = winsorize(adata.X[:, i], limits=[0, 0.01]) + for i in range(N_MORPHOLOGY): + adata.obsm[cell_morphology_obsm_key][:, i] = winsorize( + adata.obsm[cell_morphology_obsm_key][:, i], limits=[0, 0.01] + ) + + if arctanh_corrs: + adata.obsm[protein_correlations_obsm_key] = np.arctanh( + adata.obsm[protein_correlations_obsm_key] + ) + + samples_list = ( + adata.obs["Sample_name"].unique().tolist() + ) # samples in the adata + + samples_train, samples_test = train_test_split( + samples_list, train_size=train_prop, random_state=random_seed + ) + adata_train = adata.copy()[adata.obs["Sample_name"].isin(samples_train), :] + adata_test = adata.copy()[adata.obs["Sample_name"].isin(samples_test), :] + + data_train = ScModeDataloader.ScModeDataloader(adata_train) + data_test = ScModeDataloader.ScModeDataloader(adata_test, data_train.scalers) + + features_ranges = { + "Train expression min/max": (data_train.Y.min(), data_train.Y.max()), + "Train correlation min/max": (data_train.S.min(), data_train.S.max()), + "Train morphology 
min/max": (data_train.M.min(), data_train.M.max()), + "Train spatial context min/max": (data_train.C.min(), data_train.C.max()), + "Test expression min/max": (data_test.Y.min(), data_test.Y.max()), + "Test correlation min/max": (data_test.S.min(), data_test.S.max()), + "Test morphology min/max": (data_test.M.min(), data_test.M.max()), + "Test spatial context min/max": (data_test.C.min(), data_test.C.max()), + } + + loader_train = DataLoader( + data_train, batch_size=batch_size, shuffle=True, num_workers=64 + ) + loader_test = DataLoader( + data_test, batch_size=batch_size, num_workers=64 + ) # shuffle=True) + + if image_correct: + n_samples = len(samples_train) + # print("n_samples", n_samples) + # print("cat_list", cat_list.shape) + print("one-hot+covs", n_samples + n_cats) + else: + n_samples = 0 + print("n_cats", n_cats) + + if is_trained_model: + return ( + adata_train, + adata_test, + data_train, + data_test, + n_cats + n_samples, + ) + + else: + + return ( + loader_train, + loader_test, + n_samples, + features_ranges, + ) diff --git a/hmivae/_hmivae_module.py b/hmivae/_hmivae_module.py new file mode 100644 index 0000000..34784a5 --- /dev/null +++ b/hmivae/_hmivae_module.py @@ -0,0 +1,755 @@ +from typing import Literal, Optional, Sequence, Union + +import numpy as np +import pytorch_lightning as pl +import torch +import torch.nn.functional as F + +# import hmivae +from hmivae._hmivae_base_components import DecoderHMIVAE, EncoderHMIVAE + +# from pytorch_lightning.callbacks import Callback + + +# from anndata import AnnData + +torch.backends.cudnn.benchmark = True + + +class hmiVAE(pl.LightningModule): + """ + Variational Autoencoder for hmiVAE based on pytorch-lightning. 
+ """ + + def __init__( + self, + input_exp_dim: int, + input_corr_dim: int, + input_morph_dim: int, + input_spcont_dim: int, + E_me: int = 32, + E_cr: int = 32, + E_mr: int = 32, + E_sc: int = 32, + E_cov: int = 10, + latent_dim: int = 10, + n_covariates: int = 0, + leave_out_view: Optional[ + Union[None, Literal["expression", "correlation", "morphology", "spatial"]] + ] = None, + use_covs: bool = False, + use_weights: bool = True, + linear_decoder: Optional[bool] = False, + n_hidden: int = 1, + beta_scheme: Optional[Literal["constant", "warmup"]] = "warmup", + batch_correct: bool = True, + n_steps_kl_warmup: Union[int, None] = None, + n_epochs_kl_warmup: Union[int, None] = 10, + ): + super().__init__() + # hidden_dim = E_me + E_cr + E_mr + E_sc + self.n_steps_kl_warmup = n_steps_kl_warmup + self.n_epochs_kl_warmup = n_epochs_kl_warmup + self.n_covariates = n_covariates + + # self.cat_list = cat_list + + self.batch_correct = batch_correct + + self.use_covs = use_covs + + self.use_weights = use_weights + + self.leave_out_view = leave_out_view + + self.beta_scheme = beta_scheme + + self.encoder = EncoderHMIVAE( + input_exp_dim, + input_corr_dim, + input_morph_dim, + input_spcont_dim, + E_me, + E_cr, + E_mr, + E_sc, + latent_dim, + E_cov=E_cov, + n_covariates=n_covariates, + leave_out_view=leave_out_view, + n_hidden=n_hidden, + ) + + self.decoder = DecoderHMIVAE( + latent_dim, + E_me, + E_cr, + E_mr, + E_sc, + input_exp_dim, + input_corr_dim, + input_morph_dim, + input_spcont_dim, + E_cov=E_cov, + n_covariates=n_covariates, + leave_out_view=leave_out_view, + n_hidden=n_hidden, + linear_decoder=linear_decoder, + ) + + self.save_hyperparameters(ignore=["adata"]) + + def reparameterization(self, mu, log_std): + std = torch.exp(log_std) + eps = torch.randn_like(log_std) + + # sampling from encoded distribution + z_samples = mu + eps * std + + return z_samples + + def KL_div(self, enc_x_mu, enc_x_logstd, z): + """Takes in the encoded x mu and sigma, and the z sampled from 
+ q, and outputs the KL-Divergence term in ELBO""" + + p = torch.distributions.Normal( + torch.zeros_like(enc_x_mu), torch.ones_like(enc_x_logstd) + ) + enc_x_std = torch.exp(enc_x_logstd) + q = torch.distributions.Normal(enc_x_mu, enc_x_std + 1e-6) + + log_q_zx = q.log_prob(z) + log_p_z = p.log_prob(z) + + kl = log_q_zx - log_p_z + kl = kl.sum(-1) + + return kl + + def compute_kl_weight( + self, + epoch: int, + step: Optional[int], + n_epochs_kl_warmup: Optional[int], + n_steps_kl_warmup: Optional[int], + max_kl_weight: float = 1.0, + min_kl_weight: float = 0.0, + ) -> float: + """ + Compute the weight for the KL-Div term in loss. + Taken from scVI: + https://github.com/scverse/scvi-tools/blob/2c22bda9bcfb5a89d62c96c4ad39d8a1e297eb08/scvi/train/_trainingplans.py#L31 + """ + slope = max_kl_weight - min_kl_weight + + if min_kl_weight > max_kl_weight: + raise ValueError( + f"min_kl_weight={min_kl_weight} is larger than max_kl_weight={max_kl_weight}" + ) + + if n_epochs_kl_warmup: + if epoch < n_epochs_kl_warmup: + return slope * (epoch / n_epochs_kl_warmup) + min_kl_weight + elif n_steps_kl_warmup: + if step < n_steps_kl_warmup: + return slope * (step / n_steps_kl_warmup) + min_kl_weight + + return max_kl_weight + + def em_recon_loss( + self, + dec_x_mu_exp, + dec_x_logstd_exp, + dec_x_mu_corr, + dec_x_logstd_corr, + dec_x_mu_morph, + dec_x_logstd_morph, + dec_x_mu_spcont, + dec_x_logstd_spcont, + y, + s, + m, + c, + weights: Optional[Union[None, torch.Tensor]] = None, + ): + """Takes in the parameters output from the decoder, + and the original input x, and gives the reconstruction + loss term in ELBO + dec_x_mu_exp: torch.Tensor, decoded means for protein expression feature + dec_x_logstd_exp: torch.Tensor, decoded log std for protein expression feature + dec_x_mu_corr: torch.Tensor, decoded means for correlation feature + dec_x_logstd_corr: torch.Tensor, decoded log std for correlations feature + dec_x_mu_morph: torch.Tensor, decoded means for morphology feature + 
        dec_x_logstd_morph: torch.Tensor, decoded log std for morphology feature
        dec_x_mu_spcont: torch.Tensor, decoded means for spatial context feature
        dec_x_logstd_spcont: torch.Tensor, decoded log std for spatial context feature
        y: torch.Tensor, original mean expression input
        s: torch.Tensor, original correlation input
        m: torch.Tensor, original morphology input
        c: torch.Tensor, original cell context input
        weights: torch.Tensor, weights calculated from decoded means for protein expression feature

        Returns one per-cell Gaussian log-likelihood tensor per view
        (expression, correlations, morphology, spatial context).
        """

        ## Mean expression: Gaussian log-likelihood, summed over features.
        dec_x_std_exp = torch.exp(dec_x_logstd_exp)
        p_rec_exp = torch.distributions.Normal(dec_x_mu_exp, dec_x_std_exp + 1e-6)
        log_p_xz_exp = p_rec_exp.log_prob(y)
        log_p_xz_exp = log_p_xz_exp.sum(-1)

        ## Correlations: optionally weighted per-element before summing.
        dec_x_std_corr = torch.exp(dec_x_logstd_corr)
        p_rec_corr = torch.distributions.Normal(dec_x_mu_corr, dec_x_std_corr + 1e-6)
        if weights is None:
            log_p_xz_corr = p_rec_corr.log_prob(s)
        else:
            log_p_xz_corr = torch.mul(
                weights, p_rec_corr.log_prob(s)
            )  # does element-wise multiplication
        log_p_xz_corr = log_p_xz_corr.sum(-1)

        ## Morphology
        dec_x_std_morph = torch.exp(dec_x_logstd_morph)
        p_rec_morph = torch.distributions.Normal(dec_x_mu_morph, dec_x_std_morph + 1e-6)
        log_p_xz_morph = p_rec_morph.log_prob(m)
        log_p_xz_morph = log_p_xz_morph.sum(-1)

        ## Spatial context
        dec_x_std_spcont = torch.exp(dec_x_logstd_spcont)
        p_rec_spcont = torch.distributions.Normal(
            dec_x_mu_spcont, dec_x_std_spcont + 1e-6
        )
        log_p_xz_spcont = p_rec_spcont.log_prob(c)  # already dense matrix
        log_p_xz_spcont = log_p_xz_spcont.sum(-1)

        return (
            log_p_xz_exp,
            log_p_xz_corr,
            log_p_xz_morph,
            log_p_xz_spcont,
        )

    def neg_ELBO(
        self,
        enc_x_mu,
        enc_x_logstd,
        dec_x_mu_exp,
        dec_x_logstd_exp,
        dec_x_mu_corr,
        dec_x_logstd_corr,
        dec_x_mu_morph,
        dec_x_logstd_morph,
        dec_x_mu_spcont,
        dec_x_logstd_spcont,
        z,
        y,
        s,
        m,
        c,
        weights:
        Optional[Union[None, torch.Tensor]] = None,
    ):
        """Bundle the two ELBO terms: KL divergence of the encoder posterior
        plus the per-view reconstruction log-likelihoods from the decoder."""
        kl_div = self.KL_div(enc_x_mu, enc_x_logstd, z)

        (
            recon_lik_me,
            recon_lik_corr,
            recon_lik_mor,
            recon_lik_sc,
        ) = self.em_recon_loss(
            dec_x_mu_exp,
            dec_x_logstd_exp,
            dec_x_mu_corr,
            dec_x_logstd_corr,
            dec_x_mu_morph,
            dec_x_logstd_morph,
            dec_x_mu_spcont,
            dec_x_logstd_spcont,
            y,
            s,
            m,
            c,
            weights,
        )
        return (
            kl_div,
            recon_lik_me,
            recon_lik_corr,
            recon_lik_mor,
            recon_lik_sc,
        )

    def loss(self, kl_div, recon_loss, beta: float = 1.0):
        """Negative ELBO: beta-weighted mean KL minus mean reconstruction
        log-likelihood (so minimizing this maximizes the ELBO)."""
        return beta * kl_div.mean() - recon_loss.mean()

    def training_step(
        self,
        train_batch,
        # NOTE(review): mutable np.array default argument -- it is only
        # reassigned below (never mutated in place), so it is benign today,
        # but a None default with in-body construction would be safer.
        recon_weights=np.array([1.0, 1.0, 1.0, 1.0]),
    ):
        """
        Carries out the training step.
        train_batch: torch.Tensor. Training data,
        spatial_context: torch.Tensor. Matrix with old mu_z integrated neighbours information,
        corr_weights: numpy.array. Array with weights for the correlations for each cell.
        recon_weights: numpy.array. Array with weights for each view during loss calculation.
        beta: float. Coefficient for KL-Divergence term in ELBO.
+ """ + + Y = train_batch[0] + S = train_batch[1] + M = train_batch[2] + spatial_context = train_batch[3] + # batch_idx = train_batch[-1] + if self.use_weights: + weights = train_batch[5] + else: + weights = None + + if self.use_covs: + categories = train_batch[6] + else: + categories = torch.Tensor([]).type_as(Y) + + if self.batch_correct: + one_hot = train_batch[4] + + cov_list = torch.cat([one_hot, categories], 1).float().type_as(Y) + elif self.use_covs: + cov_list = categories + else: + cov_list = torch.Tensor([]).type_as(Y) + + mu_z, log_std_z = self.encoder(Y, S, M, spatial_context, cov_list) + + z_samples = self.reparameterization(mu_z, log_std_z) + + # decoding + ( + mu_x_exp_hat, + log_std_x_exp_hat, + mu_x_corr_hat, + log_std_x_corr_hat, + mu_x_morph_hat, + log_std_x_morph_hat, + mu_x_spcont_hat, + log_std_x_spcont_hat, + ) = self.decoder(z_samples, cov_list) + + ( + kl_div, + recon_lik_me, + recon_lik_corr, + recon_lik_mor, + recon_lik_sc, + ) = self.neg_ELBO( + mu_z, + log_std_z, + mu_x_exp_hat, + log_std_x_exp_hat, + mu_x_corr_hat, + log_std_x_corr_hat, + mu_x_morph_hat, + log_std_x_morph_hat, + mu_x_spcont_hat, + log_std_x_spcont_hat, + z_samples, + Y, + S, + M, + spatial_context, + weights, + ) + + if self.beta_scheme == "warmup": + + beta = self.compute_kl_weight( + self.current_epoch, + self.global_step, + self.n_epochs_kl_warmup, + self.n_steps_kl_warmup, + ) + else: + beta = 1.0 + + # print('beta=', beta) + + if self.leave_out_view is not None: + if self.leave_out_view == "expression": + recon_weights = np.array([0.0, 1.0, 1.0, 1.0]) + if self.leave_out_view == "correlation": + recon_weights = np.array([1.0, 0.0, 1.0, 1.0]) + if self.leave_out_view == "morphology": + recon_weights = np.array([1.0, 1.0, 0.0, 1.0]) + if self.leave_out_view == "spatial": + recon_weights = np.array([1.0, 1.0, 1.0, 0.0]) + else: + recon_weights = np.array([1.0, 1.0, 1.0, 1.0]) + + recon_loss = ( + recon_weights[0] * recon_lik_me + + recon_weights[1] * recon_lik_corr + 
            + recon_weights[2] * recon_lik_mor
            + recon_weights[3] * recon_lik_sc
        )

        loss = self.loss(kl_div, recon_loss, beta=beta)

        # Log scalar diagnostics for the trainer/W&B dashboards.
        self.log("train_loss", loss, on_step=True, on_epoch=True, prog_bar=True)
        self.log("beta", beta, on_step=True, on_epoch=True, prog_bar=False)
        self.log(
            "kl_div", kl_div.mean().item(), on_step=True, on_epoch=True, prog_bar=False
        )
        self.log(
            "recon_lik",
            recon_loss.mean().item(),
            on_step=True,
            on_epoch=True,
            prog_bar=False,
        )
        self.log(
            "recon_lik_me",
            recon_lik_me.mean().item(),
            on_step=True,
            on_epoch=True,
            prog_bar=False,
        )
        self.log(
            "recon_lik_corr",
            recon_lik_corr.mean().item(),
            on_step=True,
            on_epoch=True,
            prog_bar=False,
        )
        self.log(
            "recon_lik_mor",
            recon_lik_mor.mean().item(),
            on_step=True,
            on_epoch=True,
            prog_bar=False,
        )
        self.log(
            "recon_lik_sc",
            recon_lik_sc.mean().item(),
            on_step=True,
            on_epoch=True,
            prog_bar=False,
        )

        return {
            "loss": loss,
            "kl_div": kl_div.mean().item(),
            "recon_lik": recon_loss.mean().item(),
            "recon_lik_me": recon_lik_me.mean().item(),
            "recon_lik_corr": recon_lik_corr.mean().item(),
            "recon_lik_mor": recon_lik_mor.mean().item(),
            "recon_lik_sc": recon_lik_sc.mean().item(),
        }

    def validation_step(
        self,
        test_batch,
        # NOTE(review): n_other_cat is never read in the body, and L_iter is
        # referenced only in commented-out code -- candidates for removal.
        n_other_cat: int = 0,
        L_iter: int = 300,
    ):
        """---> Add random one-hot encoding
        Carries out the validation/test step.
        test_batch: torch.Tensor. Validation/test data,
        spatial_context: torch.Tensor. Matrix with old mu_z integrated neighbours information,
        corr_weights: numpy.array. Array with weights for the correlations for each cell.
        recon_weights: numpy.array. Array with weights for each view during loss calculation.
        beta: float. Coefficient for KL-Divergence term in ELBO.
+ """ + Y = test_batch[0] + S = test_batch[1] + M = test_batch[2] + spatial_context = test_batch[3] + batch_idx = test_batch[-1] + + if self.use_weights: + weights = test_batch[5] + else: + weights = None + + if self.use_covs: + categories = test_batch[6] + n_classes = self.n_covariates - categories.shape[1] + else: + categories = torch.Tensor([]).type_as(Y) + n_classes = self.n_covariates + + test_loss = torch.empty(size=[len(batch_idx), n_classes]) + elbo_full = torch.empty(size=[len(batch_idx), n_classes]) + + # for i in range(L_iter): + for i in range(n_classes): + + if self.batch_correct: + # one_hot = self.random_one_hot( + # n_classes=n_classes, n_samples=len(batch_idx) + # ).type_as(Y) + + one_hot_zeros = torch.zeros(size=[1, n_classes]) + + one_hot_zeros[0, i] = 1.0 + + one_hot = one_hot_zeros.repeat((len(batch_idx), 1)).type_as(Y) + + cov_list = torch.cat([one_hot, categories], 1).float().type_as(Y) + elif self.use_covs: + cov_list = categories + else: + cov_list = torch.Tensor([]).type_as(Y) + + mu_z, log_std_z = self.encoder( + Y, S, M, spatial_context, cov_list + ) # valid step + + z_samples = self.reparameterization(mu_z, log_std_z) + + # decoding + ( + mu_x_exp_hat, + log_std_x_exp_hat, + mu_x_corr_hat, + log_std_x_corr_hat, + mu_x_morph_hat, + log_std_x_morph_hat, + mu_x_spcont_hat, + log_std_x_spcont_hat, + ) = self.decoder(z_samples, cov_list) + + ( + kl_div, + recon_lik_me, + recon_lik_corr, + recon_lik_mor, + recon_lik_sc, + ) = self.neg_ELBO( + mu_z, + log_std_z, + mu_x_exp_hat, + log_std_x_exp_hat, + mu_x_corr_hat, + log_std_x_corr_hat, + mu_x_morph_hat, + log_std_x_morph_hat, + mu_x_spcont_hat, + log_std_x_spcont_hat, + z_samples, + Y, + S, + M, + spatial_context, + weights, + ) + + if self.beta_scheme == "warmup": + + beta = self.compute_kl_weight( + self.current_epoch, + self.global_step, + self.n_epochs_kl_warmup, + self.n_steps_kl_warmup, + ) + else: + beta = 1.0 + + if self.leave_out_view is not None: + if self.leave_out_view == 
"expression": + recon_weights = np.array([0.0, 1.0, 1.0, 1.0]) + if self.leave_out_view == "correlation": + recon_weights = np.array([1.0, 0.0, 1.0, 1.0]) + if self.leave_out_view == "morphology": + recon_weights = np.array([1.0, 1.0, 0.0, 1.0]) + if self.leave_out_view == "spatial": + recon_weights = np.array([1.0, 1.0, 1.0, 0.0]) + else: + recon_weights = np.array([1.0, 1.0, 1.0, 1.0]) + + recon_loss = ( + recon_weights[0] * recon_lik_me + + recon_weights[1] * recon_lik_corr + + recon_weights[2] * recon_lik_mor + + recon_weights[3] * recon_lik_sc + ) + + loss = self.loss(kl_div, recon_loss, beta=beta) + + full_elbo = recon_loss.mean() - kl_div.mean() + + test_loss[:, i] = loss + + elbo_full[:, i] = full_elbo + + self.log( + "test_loss", + # sum(test_loss) / L_iter, + test_loss.mean(1).mean(), + on_step=True, + on_epoch=True, + prog_bar=True, + ) # log the average test loss over all the iterations + self.log( + "test_full_elbo", + elbo_full.mean(1).mean(), + on_step=True, + on_epoch=True, + prog_bar=False, + ) + self.log( + "kl_div_test", + kl_div.mean().item(), + on_step=True, + on_epoch=True, + prog_bar=False, + ) + self.log("beta_test", beta, on_step=True, on_epoch=True, prog_bar=False) + self.log( + "recon_lik_test", + recon_loss.mean().item(), + on_step=True, + on_epoch=True, + prog_bar=False, + ) + self.log( + "recon_lik_me_test", + recon_lik_me.mean().item(), + on_step=True, + on_epoch=True, + prog_bar=False, + ) + self.log( + "recon_lik_corr_test", + recon_lik_corr.mean().item(), + on_step=True, + on_epoch=True, + prog_bar=False, + ) + self.log( + "recon_lik_mor_test", + recon_lik_mor.mean().item(), + on_step=True, + on_epoch=True, + prog_bar=False, + ) + self.log( + "recon_lik_sc_test", + recon_lik_sc.mean().item(), + on_step=True, + on_epoch=True, + prog_bar=False, + ) + + return { + "loss": test_loss.mean(1).mean(), + "kl_div": kl_div.mean().item(), + "recon_lik": recon_loss.mean().item(), + "recon_lik_me": recon_lik_me.mean().item(), + 
"recon_lik_corr": recon_lik_corr.mean().item(), + "recon_lik_mor": recon_lik_mor.mean().item(), + "recon_lik_sc": recon_lik_sc.mean().item(), + } + + def configure_optimizers(self): + """Optimizer""" + parameters = list(self.encoder.parameters()) + list(self.decoder.parameters()) + optimizer = torch.optim.Adam(parameters, lr=1e-3) + return optimizer + + @torch.no_grad() + def get_input_embeddings( + self, x_mean, x_correlations, x_morphology, x_spatial_context + ): + """ + Returns the view-specific embeddings. + """ + h_mean = F.elu(self.input_exp(x_mean)) + h_mean2 = F.elu(self.exp_hidden(h_mean)) + + h_correlations = F.elu(self.input_corr(x_correlations)) + h_correlations2 = F.elu(self.corr_hidden(h_correlations)) + + h_morphology = F.elu(self.input_morph(x_morphology)) + h_morphology2 = F.elu(self.morph_hidden(h_morphology)) + + h_spatial_context = F.elu(self.input_spatial_context(x_spatial_context)) + h_spatial_context2 = F.elu(self.spatial_context_hidden(h_spatial_context)) + + return h_mean2, h_correlations2, h_morphology2, h_spatial_context2 + + @torch.no_grad() + def inference( + self, + data, + n_covariates: int, + use_covs: bool = True, + batch_correct: bool = True, + indices: Optional[Sequence[int]] = None, + give_mean: bool = True, + # idx = None, + ) -> np.ndarray: + """ + Return the latent representation of each cell. 
+ """ + # if self.leave_out_view is None: + Y = data.Y + S = data.S + M = data.M + C = data.C + # batch_idx = idx + # print(batch_idx) + # if self.use_covs: + # categories = data.BKG + # n_cats = categories.shape[1] + # else: + # categories = torch.Tensor([]) + # n_cats = 0 + if use_covs: + categories = data.BKG + n_classes = n_covariates - categories.shape[1] + else: + categories = torch.Tensor([]).type_as(Y) + n_classes = n_covariates + + if batch_correct: + one_hot = self.random_one_hot( + n_classes=n_classes, n_samples=Y.shape[0] + ).type_as(Y) + # one_hot = data.samples_onehot + # if one_hot.shape[1] < self.n_covariates - n_cats: + # zeros_pad = torch.Tensor( + # np.zeros( + # [ + # one_hot.shape[0], + # (self.n_covariates - n_cats) - one_hot.shape[1], + # ] + # ) + # ) + # one_hot = torch.cat([one_hot, zeros_pad], 1) + # else: + # one_hot = one_hot + + cov_list = torch.cat([one_hot, categories], 1).float() + else: + cov_list = torch.Tensor([]) + + if give_mean: + mu_z, _ = self.encoder(Y, S, M, C, cov_list) + + return mu_z.numpy() + else: + mu_z, log_std_z = self.encoder(Y, S, M, C, cov_list) + + z = self.reparameterization(mu_z, log_std_z) + + return z.numpy() + + @torch.no_grad() + def random_one_hot(self, n_classes: int, n_samples: int): + """ + Generates a random one hot encoded matrix. 
        From: https://stackoverflow.com/questions/45093615/random-one-hot-matrix-in-numpy
        """
        # x = np.eye(n_classes)
        # Row i of eye(n_classes) is the one-hot vector for class i; index
        # with uniform random class choices to get (n_samples, n_classes).
        return torch.Tensor(np.eye(n_classes)[np.random.choice(n_classes, n_samples)])
diff --git a/mypackage/_mypyromodel.py b/hmivae/_mypyromodel.py
similarity index 100%
rename from mypackage/_mypyromodel.py
rename to hmivae/_mypyromodel.py
diff --git a/mypackage/_mypyromodule.py b/hmivae/_mypyromodule.py
similarity index 100%
rename from mypackage/_mypyromodule.py
rename to hmivae/_mypyromodule.py
diff --git a/hmivae/clinical_associations_latent_dim_and_cluster_prevalence.py b/hmivae/clinical_associations_latent_dim_and_cluster_prevalence.py
new file mode 100644
index 0000000..20f749c
--- /dev/null
+++ b/hmivae/clinical_associations_latent_dim_and_cluster_prevalence.py
@@ -0,0 +1,432 @@
### Clinical associations
# Analysis script: associate latent dimensions and cluster prevalence
# with clinical variables via logistic regression.

from collections import Counter

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scanpy as sc
import statsmodels.api as sm
import tifffile
from rich.progress import (
    BarColumn,
    Progress,
    TaskProgressColumn,
    TextColumn,
    TimeElapsedColumn,
)

### Load data

cohort = "basel"

adata = sc.read_h5ad(
    f"../cluster_analysis/{cohort}/best_run_{cohort}_no_dna/results_diff_res/{cohort}_adata_new.h5ad"
)  # directory where adata was stored

patient_data = pd.read_csv(
    f"{cohort}/{cohort}_survival_patient_samples.tsv", sep="\t", index_col=0
)

clinical_variables = [
    "ERStatus",
    "grade",
    "PRStatus",
    "HER2Status",
    "Subtype",
    "clinical_type",
    "HR",
]  # changes for each cohort, here example is basel

patient_col = "PID"

cluster_col = "leiden"

### Visualize the data

plt.rcParams["figure.figsize"] = [10, 10]

# One bar chart of value counts per clinical variable, on a 4x2 grid.
for n, i in enumerate(clinical_variables):
    ax = plt.subplot(4, 2, n + 1)
    df = pd.DataFrame(patient_data[i].value_counts()).transpose()

    df.plot.bar(ax=ax)

    ax.set_xticklabels([i], rotation=0)
    plt.legend(bbox_to_anchor=[1.0, 1.1])

# Patient / Latent Variable associations
# Variable associations (per-sample median latent embedding vs. clinical
# variables)

df = pd.DataFrame(
    columns=["Sample_name"]
    + [f"median_latent_dim_{n}" for n in range(adata.obsm["VAE"].shape[1])]
)

# One row per sample: median of each VAE latent dimension over its cells.
for n, sample in enumerate(adata.obs.Sample_name.unique()):
    sample_adata = adata.copy()[adata.obs.Sample_name.isin([sample]), :]

    df.loc[str(n)] = [sample] + np.median(sample_adata.obsm["VAE"], axis=0).tolist()

patient_latent = pd.merge(df, patient_data, on="Sample_name")

## first try

latent_dim_cols = [i for i in patient_latent.columns if "median" in i]
exception_variables = []

dfs = []

# cvar is clinical variable
# sub_cvar are the values the clinical variable can take on e.g. for cvar == ERStatus, sub_cvar == pos or sub_cvar == neg

for cvar in clinical_variables:  # using all latent dims for this pass
    cvar_dfs = []

    for sub_cvar in patient_latent[cvar].unique():
        print(cvar, sub_cvar)
        sub_cvar_df = pd.DataFrame({})
        selected_df = patient_latent.copy()[
            ~patient_latent[cvar].isna()
        ]  # drop nan values for each var
        selected_df[cvar] = list(
            map(int, selected_df[cvar] == sub_cvar)
        )  # map 1 and 0 for entries that belong to the sub_cvar

        X = selected_df[
            latent_dim_cols
        ].to_numpy()  # select columns corresponding to latent dims and convert to numpy
        X = sm.add_constant(X)  # add constant
        y = selected_df[  # this is the 0 and 1 col
            cvar
        ].to_numpy()  # select the clinical variable column and convert to numpy -- no fillna(0) since all the nans should have been dropped
        try:
            log_reg = sm.Logit(y, X).fit()  # fit the Logistic Regression model

            sub_cvar_df["latent_dim"] = [c.split("_")[-1] for c in latent_dim_cols]

            sub_cvar_df["tvalues"] = log_reg.tvalues[1:]  # remove the constant

            sub_cvar_df["clinical_variable"] = [
                f"{cvar}:{sub_cvar}"
            ] * sub_cvar_df.shape[0]

            cvar_dfs.append(sub_cvar_df)
        except Exception as e:
            # Logit can fail (e.g. perfect separation); record and move on.
            exception_variables.append((cvar, sub_cvar))
            print(f"{cvar}:{sub_cvar} had an exception occur: {e}")

    full_cvar_dfs = pd.concat(cvar_dfs)

    dfs.append(full_cvar_dfs)

## Second try, which features caused issues for which clinical variable

features_to_remove = []
# cvar_dfs2 = []
for cvar, sub_cvar in exception_variables:
    # print(cvar, sub_cvar)
    selected_df = patient_latent[
        ~patient_latent[cvar].isna()
    ].copy()  # drop nan values for each var
    selected_df[cvar] = list(map(int, selected_df[cvar] == sub_cvar))
    y = selected_df[cvar].to_numpy()
    X = selected_df[
        latent_dim_cols
    ].to_numpy()  # select columns corresponding to latent dims and convert to numpy

    perf_sep_features = []
    for i in range(
        X.shape[1]
    ):  # introduce each latent dim one at a time to see which caused issues
        X_1 = X.copy()[:, 0 : i + 1]
        X_1 = sm.add_constant(X_1)  # add constant
        try:
            log_reg = sm.Logit(y, X_1).fit()  # fit the Logistic Regression model
            print(
                f"Completed: tvalues for {cvar}:{sub_cvar}, features till {i} -> {log_reg.tvalues}"
            )
            # print(log_reg.summary())
        except Exception as e:
            print(f"{cvar}:{sub_cvar} for feature {i} has exception: {e}")
            perf_sep_features.append(i)  # store the issue causing latent dim

    # if len(perf_sep_features) == 0:
    #     sub_cvar_df = pd.DataFrame({})
    #     sub_cvar_df['latent_dim'] = [c.split('_')[-1] for c in latent_dim_cols]

    #     assert len(log_reg.tvalues) == X.shape[1]+1 #for constant -- check this is the last one

    #     sub_cvar_df['tvalues'] = log_reg.tvalues[1:] # remove the constant -- this should be the last one

    #     sub_cvar_df['clinical_variable'] = [f"{cvar}:{sub_cvar}"]*sub_cvar_df.shape[0]

    #     cvar_dfs2.append(sub_cvar_df) # this will often turn out to be empty since if it gave issues before, it should give issues now

    # else:

    features_to_remove.append((cvar, sub_cvar, perf_sep_features))

## final try, remove the features causing issues and store their t-value as NaN

sub_cvars = []

for cvar, sub_cvar, del_inds in features_to_remove:
    selected_df = patient_latent[
        ~patient_latent[cvar].isna()
    ].copy()  # drop nan
    # values for each var (comment continued from the line above)
    selected_df[cvar] = list(map(int, selected_df[cvar] == sub_cvar))
    y = selected_df[cvar].to_numpy()
    X = selected_df[
        latent_dim_cols
    ].to_numpy()  # select columns corresponding to latent dims and convert to numpy
    del_inds = del_inds
    X = np.delete(
        X, del_inds, axis=1
    )  # delete the issue causing latent dims from full set
    print(X.shape)
    X = sm.add_constant(X)  # add constant
    try:
        log_reg = sm.Logit(
            y, X
        ).fit()  # fit the Logistic Regression model on the remaining latent dims
        # NOTE(review): `i` below is stale -- it is the leftover loop variable
        # from the "second try" loop above, not an index of this loop.
        print(
            f"Completed: tvalues for {cvar}:{sub_cvar}, features till {i} -> {log_reg.tvalues}"
        )

        sub_cvar_df = pd.DataFrame({})
        sub_cvar_df["latent_dim"] = [c.split("_")[-1] for c in latent_dim_cols]

        tvalues = log_reg.tvalues[1:].tolist()  # + [np.nan]

        # NOTE(review): the i > len(tvalues) / i - 1 branch looks fragile for
        # multiple deleted indices -- verify NaNs land on the right dims.
        for i in del_inds:  # for latent dims that caused issues, store their tvalues as nan so we know which ones didn't work
            if i > len(tvalues):
                tvalues = np.insert(tvalues, i - 1, np.nan)
            else:
                tvalues = np.insert(tvalues, i, np.nan)

        # tvalues = np.insert(tvalues, del_inds.remove(19), np.nan)
        # assert len(log_reg.tvalues) == X.shape[1]+1 #for constant -- check this is the last one

        sub_cvar_df["tvalues"] = tvalues

        sub_cvar_df["clinical_variable"] = [f"{cvar}:{sub_cvar}"] * sub_cvar_df.shape[0]

        sub_cvars.append(sub_cvar_df)
    except Exception as e:
        print(f"{cvar}:{sub_cvar} for feature {i} has exception: {e}")

sub_cvar_df1 = pd.concat(sub_cvars)

full_clin_df = pd.concat(dfs).reset_index(drop=True)

final_full_clin_df = pd.concat([full_clin_df, sub_cvar_df1]).reset_index(drop=True)

final_full_clin_df = pd.pivot_table(
    final_full_clin_df,
    index="clinical_variable",
    values="tvalues",
    columns="latent_dim",
)  # df that's plotted


# Patient / Cluster associations
# First we need to define cluster prevalence within a patient. Doing this in two ways:
# 1. How we were doing it before -- proportion of cells in patient x that belong to cluster c
# 2. Cells of cluster c per mm^2 of tissue
# (continued) Cells of cluster c per mm^2 of tissue

clusters_patient = pd.merge(
    adata.obs.reset_index()[["Sample_name", "leiden", "cell_id"]],
    patient_data.reset_index(),
    on="Sample_name",
)

## Option 1: Proportion of cells in patient x that belong in cluster c

hi_or_low = clusters_patient[[patient_col, cluster_col]]

## Proportion of cells belonging to each cluster for each image / patient

hi_or_low = hi_or_low.groupby([patient_col, cluster_col]).size().unstack(fill_value=0)


hi_or_low = hi_or_low.div(
    hi_or_low.sum(axis=1), axis=0
).fillna(0)  # get proportion of each cluster in each patient (all will sum to 1)


hi_low_cluster_variables = (
    pd.merge(
        hi_or_low.reset_index(),
        clusters_patient[clinical_variables + [patient_col]],
        on=patient_col,
    )
    .drop_duplicates()
    .reset_index(drop=True)
)

prop_cluster_cols = [
    i
    for i in hi_low_cluster_variables.columns
    if i in clusters_patient[cluster_col].unique()
]
exception_variables = []

dfs = []

for cvar in clinical_variables:
    cvar_dfs = []
    filtered_df = hi_low_cluster_variables[
        ~hi_low_cluster_variables[cvar].isna()
    ].copy()  # drop nan values for each var

    for sub_cvar in filtered_df[cvar].unique():
        print(cvar, sub_cvar)
        selected_df = filtered_df.copy()
        selected_df[cvar] = list(map(int, selected_df[cvar] == sub_cvar))
        sub_cvar_df = pd.DataFrame({})
        y = selected_df[
            cvar
        ].to_numpy()  # select the clinical variable column and convert to numpy -- no fillna(0) since all the nans should have been dropped
        X = selected_df[
            prop_cluster_cols
        ].to_numpy()  # select columns corresponding to cluster proportions and convert to numpy
        tvalues = {}
        for cluster in range(
            X.shape[1]
        ):  # do each cluster one by one since these add up to 1 and Logit won't work
            X1 = X[:, cluster]
            X1 = sm.add_constant(X1)
            try:
                log_reg = sm.Logit(y, X1).fit()  # fit the Logistic Regression model

                tvalues[cluster] = log_reg.tvalues[
                    1
                ]  # there will be 2 t values, first one belongs to the constant
                # belongs to the constant (comment continued from above)
            except Exception as e:
                exception_variables.append((cvar, sub_cvar, cluster, e))
                print(
                    f"{cvar}:{sub_cvar} had an exception occur for cluster {cluster}: {e}"
                )

        sub_cvar_df["cluster"] = list(tvalues.keys())

        sub_cvar_df["tvalues"] = list(tvalues.values())

        sub_cvar_df["clinical_variable"] = [f"{cvar}:{sub_cvar}"] * sub_cvar_df.shape[0]

        cvar_dfs.append(sub_cvar_df)

    full_cvar_dfs = pd.concat(cvar_dfs)

    dfs.append(full_cvar_dfs)

full_cluster_clin_df = pd.concat(dfs).reset_index(drop=True)

full_cluster_clin_df = pd.pivot_table(
    full_cluster_clin_df, index="clinical_variable", values="tvalues", columns="cluster"
)  # df that's plotted

## Option 2: Number of cells per mm^2 tissue
# We're going to do this per image for now -- mainly because sizes might differ between images that belong to the same patient

clinical_variables = clinical_variables + [
    "diseasestatus"
]  # for basel, since doing per image

# Per-cohort (mask directory, mask filename suffix) pairs.
cohort_dirs = {
    "basel": ["OMEnMasks/Basel_Zuri_masks", "_a0_full_maks.tiff"],
    "metabric": ["METABRIC_IMC/to_public_repository/cell_masks", "_cellmask.tiff"],
    "melanoma": [
        "full_data/protein/cpout/",
        "_ac_ilastik_s2_Probabilities_equalized_cellmask.tiff",
    ],
}  # directories with the masks

adata_df = adata.obs.reset_index()[["cell_id", "Sample_name", "leiden"]]
clusters = adata_df.leiden.unique().tolist()

sample_dfs = []

progress = Progress(
    TextColumn(f"[progress.description]Finding cluster prevalances in {cohort}."),
    BarColumn(),
    TaskProgressColumn(),
    TimeElapsedColumn(),
)
with progress:
    for sample in progress.track(adata.obs.Sample_name.unique()):
        s_df = pd.DataFrame({})
        s_cluster_prevs = {}
        mask = tifffile.imread(
            f"../../../data/{cohort_dirs[cohort][0]}/{sample}{cohort_dirs[cohort][1]}"
        )  # get dims of the image
        sample_df = adata_df.copy().query("Sample_name==@sample")
        for cluster in clusters:
            # NOTE(review): this Counter is rebuilt per cluster; it could be
            # hoisted out of the inner loop.
            num_cells_in_sample = \
Counter(sample_df.leiden.tolist())
            num_cells_in_clusters = num_cells_in_sample[
                cluster
            ]  # get number of cells belonging to each cluster for each image

            # print(num_cells_in_clusters)
            # print(mask.shape[0] , mask.shape[1])

            cluster_prevalance_per_mm2 = (
                num_cells_in_clusters / (mask.shape[0] * mask.shape[1])
            ) * 1e6  # scale, 1 pixel == 1 micron, get the prevalence and scale

            s_cluster_prevs[cluster] = cluster_prevalance_per_mm2

        s_df["cluster"] = list(s_cluster_prevs.keys())
        s_df["prevalance_per_mm2_scaled_by_1e6"] = list(s_cluster_prevs.values())
        s_df["Sample_name"] = [sample] * s_df.shape[0]

        sample_dfs.append(s_df)

full_cohort_df = pd.concat(sample_dfs)

full_cohort_df["cluster"] = full_cohort_df["cluster"].map(int)

full_cohort_df = pd.pivot_table(
    full_cohort_df,
    values="prevalance_per_mm2_scaled_by_1e6",
    index="Sample_name",
    columns="cluster",
)

clusters = full_cohort_df.columns.tolist()  # to make sure correct order later

cluster_per_tissue_patient = pd.merge(
    full_cohort_df, patient_data[clinical_variables + ["Sample_name"]], on="Sample_name"
)

# The below is still being run and tested but this is close to what I will be doing

cluster_cols = clusters
exception_variables = []

dfs = []

for cvar in clinical_variables:
    cvar_dfs = []

    for sub_cvar in cluster_per_tissue_patient[cvar].dropna().unique().tolist():
        print(cvar, sub_cvar)
        sub_cvar_df = pd.DataFrame({})
        selected_df = cluster_per_tissue_patient.copy()[
            ~cluster_per_tissue_patient[cvar].isna()
        ]  # drop nan values for each var
        selected_df[cvar] = list(map(int, selected_df[cvar] == sub_cvar))

        X = selected_df[
            cluster_cols
        ].to_numpy()  # select columns corresponding to cluster prevalences and convert to numpy
        X = sm.add_constant(X)  # add constant
        y = selected_df[
            cvar
        ].to_numpy()  # select the clinical variable column and convert to numpy -- no fillna(0) since all the nans should have been dropped
        try:
            log_reg = \
sm.Logit(y, X).fit() # fit the Logistic Regression model, doing them altogether this time, not one by one because don't need to + + sub_cvar_df["cluster"] = [c for c in cluster_cols] + + sub_cvar_df["tvalues"] = log_reg.tvalues[1:] # remove the constant + + sub_cvar_df["clinical_variable"] = [ + f"{cvar}:{sub_cvar}" + ] * sub_cvar_df.shape[0] + + cvar_dfs.append(sub_cvar_df) + except Exception as e: + exception_variables.append((cvar, sub_cvar)) # I keep a track of the exception variables but I don't deal with them + print(f"{cvar}:{sub_cvar} had an exception occur: {e}") + + full_cvar_dfs = pd.concat(cvar_dfs) + + dfs.append(full_cvar_dfs) diff --git a/hmivae/run_hmivae.py b/hmivae/run_hmivae.py new file mode 100644 index 0000000..ce5fb68 --- /dev/null +++ b/hmivae/run_hmivae.py @@ -0,0 +1,754 @@ +## run with hmivae +import argparse +import os +import time +from collections import OrderedDict + +import matplotlib.pyplot as plt +import numpy as np +import pandas as pd + +# import phenograph +import scanpy as sc +import seaborn as sns +import squidpy as sq +import torch +import wandb +from anndata import AnnData +from rich.progress import ( # track, + BarColumn, + Progress, + TaskProgressColumn, + TextColumn, + TimeElapsedColumn, +) +from scipy.stats.mstats import winsorize +from sklearn.preprocessing import StandardScaler +from statsmodels.api import OLS, add_constant + +# import hmivae +from hmivae._hmivae_model import hmivaeModel +from hmivae.ScModeDataloader import ScModeDataloader + + +def arrange_features(vars_lst, adata): + + arranged_features = {"E": [], "C": [], "M": [], "S": []} + orig_list = vars_lst + + for i in orig_list: + + if i in adata.var_names: + arranged_features["E"].append(i) + elif i in adata.uns["names_morphology"]: + arranged_features["M"].append(i) + elif i in adata.uns["names_correlations"]: + arranged_features["C"].append(i) + else: + arranged_features["S"].append(i) + arr_list = [ + *np.sort(arranged_features["E"]).tolist(), + 
        *np.sort(arranged_features["C"]).tolist(),
        *np.sort(arranged_features["M"]).tolist(),
        *np.sort(arranged_features["S"]).tolist(),
    ]
    return arranged_features, arr_list


def create_cluster_dummy(adata, cluster_col, cluster):
    """Return an (n_cells, 1) 0/1 indicator of membership in `cluster`."""
    # n_clusters = len(adata.obs[cluster_col].unique().tolist())
    x = np.zeros([adata.X.shape[0], 1])

    # NOTE(review): assumes obs.index values are string-encoded integer
    # positions (int(cell) used both for lookup and assignment) -- confirm.
    for cell in adata.obs.index:
        # cell_cluster = int(adata.obs[cluster_col][cell])
        # print(type(cell), type(cluster))

        if adata.obs[cluster_col][int(cell)] == cluster:
            x[int(cell)] = 1

    return x


def get_feature_matrix(adata, scale_values=False, cofactor=1, weights=True):
    """Build the standardized (cells x features) matrix of expression +
    correlations + morphology, with optional winsorization/arcsinh, and
    return it together with the concatenated feature names."""
    correlations = adata.obsm["correlations"]
    if weights:
        correlations = np.multiply(
            correlations, adata.obsm["weights"]
        )  # multiply weights with correlations

    if scale_values:
        # Winsorize the top 1% per feature to tame outliers.
        morphology = adata.obsm["morphology"]
        for i in range(adata.obsm["morphology"].shape[1]):
            morphology[:, i] = winsorize(
                adata.obsm["morphology"][:, i], limits=[0, 0.01]
            )

        expression = np.arcsinh(adata.X / cofactor)
        for j in range(adata.X.shape[1]):
            expression[:, j] = winsorize(expression[:, j], limits=[0, 0.01])
    else:
        morphology = adata.obsm["morphology"]
        expression = adata.X

    y = StandardScaler().fit_transform(
        np.concatenate([expression, correlations, morphology], axis=1)
    )

    var_names = np.concatenate(
        [
            adata.var_names,
            adata.uns["names_correlations"],
            adata.uns["names_morphology"],
        ]
    )

    return y, var_names


def rank_features_in_groups(adata, group_col, scale_values=False, cofactor=1):
    """For each group in obs[group_col], regress the group indicator on all
    features (OLS) and rank features by t-value; results are written into
    adata.uns in place (no return value)."""
    progress = Progress(
        TextColumn(f"[progress.description]Ranking features in {group_col} groups"),
        BarColumn(),
        TaskProgressColumn(),
        TimeElapsedColumn(),
    )
    ranked_features_in_groups = {}
    dfs = []
    # create the feature matrix for entire adata
    y, var_names = get_feature_matrix(
        adata, scale_values=scale_values, cofactor=cofactor
    )
    y = add_constant(y)  # add intercept

    with progress:

        for group in \
progress.track(adata.obs[group_col].unique()):
            ranked_features_in_groups[str(group)] = {}
            x = create_cluster_dummy(adata, group_col, group)
            # Regress group membership on features; t-values rank features.
            mod = OLS(x, y)
            res = mod.fit()

            df_values = pd.DataFrame(
                res.tvalues[1:],  # remove the intercept value
                index=var_names,
                columns=[f"t_value_{group}"],
            ).sort_values(by=f"t_value_{group}", ascending=False)

            ranked_features_in_groups[str(group)]["names"] = df_values.index.to_list()
            ranked_features_in_groups[str(group)]["t_values"] = df_values[
                f"t_value_{group}"
            ].to_list()

            # print('df index:', df_values.index.tolist())

            dfs.append(df_values)

    fc_df = pd.concat(
        dfs, axis=1
    ).sort_index()  # index is sorted as alphabetical! (order with original var_names is NOT maintained!)

    fc_df.index = fc_df.index.map(str)
    fc_df.columns = fc_df.columns.map(str)

    # Results are stored on the AnnData object (mutated in place).
    adata.uns[f"{group_col}_ranked_features_in_groups"] = ranked_features_in_groups
    adata.uns[f"{group_col}_feature_scores"] = fc_df

    # return adata


def top_common_features(df, top_n_features=10):
    """Union of the top-N features (by absolute value) from every column of
    `df`; returns the sub-dataframe restricted to those features."""
    sets_list = []

    for i in df.columns:
        abs_sorted_col = df[i].map(abs).sort_values(ascending=False)
        for j in abs_sorted_col.index.to_list()[0:top_n_features]:
            sets_list.append(j)

    common_features = list(set(sets_list))

    common_feat_df = df.loc[common_features]

    return common_feat_df


parser = argparse.ArgumentParser(description="Run hmiVAE")

parser.add_argument(
    "--adata", type=str, required=True, help="AnnData file with all the inputs"
)

parser.add_argument(
    "--include_all_views",
    type=int,
    help="Run model using all views",
    default=1,
    choices=[0, 1],
)

parser.add_argument(
    "--remove_view",
    type=str,
    help="Name of view to leave out. One of ['expression', 'correlation', 'morphology', 'spatial']. \
def _str2bool(value):
    """Parse a command-line boolean for argparse.

    BUG FIX: these flags previously used `type=bool`, and `bool("False")`
    is True -- any non-empty string (including "False") enabled the flag,
    so the True defaults could never be turned off from the CLI.
    """
    if isinstance(value, bool):
        return value
    lowered = value.lower()
    if lowered in ("true", "t", "1", "yes", "y"):
        return True
    if lowered in ("false", "f", "0", "no", "n"):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean, got {value!r}")


parser.add_argument(
    "--use_covs",
    type=_str2bool,
    help="True/False for using background covariates",
    default=True,
)

parser.add_argument(
    "--use_weights",
    type=_str2bool,
    help="True/False for using correlation weights",
    default=True,
)

parser.add_argument(
    "--batch_correct",
    type=_str2bool,
    help="True/False for using one-hot encoding for batch correction",
    default=True,
)

parser.add_argument(
    "--batch_size",
    type=int,
    help="Batch size for train/test data, default=1234",
    default=1234,
)

parser.add_argument(
    "--hidden_dim_size",
    type=int,
    help="Size for view-specific hidden layers",
    default=32,
)

parser.add_argument(
    "--latent_dim",
    type=int,
    help="Size for the final latent representation layer",
    default=10,
)

parser.add_argument(
    "--n_hidden",
    type=int,
    help="Number of hidden layers",
    default=1,
)

parser.add_argument(
    "--beta_scheme",
    type=str,
    help="Scheme to use for beta vae",
    default="warmup",
    choices=["constant", "warmup"],
)

parser.add_argument(
    "--use_linear_decoder",
    type=_str2bool,
    help="For using a linear decoder: True or False",
    default=False,
)

parser.add_argument(
    "--cofactor", type=float, help="Cofactor for arcsinh transformation", default=1.0
)

parser.add_argument(
    "--random_seed",
    type=int,
    help="Random seed for weights initialization",
    default=1234,
)

parser.add_argument("--cohort", type=str, help="Cohort name", default="cohort")

parser.add_argument(
    "--output_dir", type=str, help="Directory to store the outputs", default="."
)
args = parser.parse_args()

# Per-run log file; the name encodes the hyper-parameters so parallel runs
# sharing an output directory do not clobber each other's logs.
# NOTE(review): this handle is never closed anywhere in the script -- it
# relies on interpreter exit; consider closing it after the final write.
log_file = open(
    os.path.join(
        args.output_dir,
        f"{args.cohort}_nhidden{args.n_hidden}_hiddendim{args.hidden_dim_size}"
        f"_latentdim{args.latent_dim}_betascheme{args.beta_scheme}"
        f"_randomseed{args.random_seed}_run_log.txt",
    ),
    "w+",
)

raw_adata = sc.read_h5ad(args.adata)

# Record raw input ranges up front -- useful when debugging scaling issues.
log_file.writelines(
    [
        f"raw adata X, max: {raw_adata.X.max()}, min: {raw_adata.X.min()} \n",
        f"raw adata correlations, max: {raw_adata.obsm['correlations'].max()}, min: {raw_adata.obsm['correlations'].min()} \n",
        f"raw adata morphology, max: {raw_adata.obsm['morphology'].max()}, min: {raw_adata.obsm['morphology'].min()} \n",
    ]
)

n_total_features = (
    raw_adata.X.shape[1]
    + raw_adata.obsm["correlations"].shape[1]
    + raw_adata.obsm["morphology"].shape[1]
)

log_file.write(f"Total number of features:{n_total_features} \n")
log_file.write(f"Total number of cells:{raw_adata.X.shape[0]} \n")

print("Set up the model")

start = time.time()

# All four view encoders share a single hidden-layer width.
E_me = E_cr = E_mr = E_sc = args.hidden_dim_size

input_exp_dim = raw_adata.shape[1]
input_corr_dim = raw_adata.obsm["correlations"].shape[1]
input_morph_dim = raw_adata.obsm["morphology"].shape[1]
input_spcont_dim = n_total_features

keys = []
if args.use_covs:
    cat_list = []
    # Every obsm entry that is not one of the known data views is treated
    # as a continuous background covariate.
    # NOTE(review): this filter also picks up "weights" and
    # "background_covs" when present -- confirm both are intended
    # covariates for the model.
    for key in raw_adata.obsm.keys():
        if key not in ["correlations", "morphology", "spatial", "xy"]:
            keys.append(key)
    for cat_key in keys:
        category = raw_adata.obsm[cat_key]
cat_list.append(category) + cat_list = np.concatenate(cat_list, 1) + n_covariates = cat_list.shape[1] + E_cov = args.hidden_dim_size +else: + n_covariates = 0 + E_cov = 0 + +model = hmivaeModel( + adata=raw_adata, + input_exp_dim=input_exp_dim, + input_corr_dim=input_corr_dim, + input_morph_dim=input_morph_dim, + input_spcont_dim=input_spcont_dim, + E_me=E_me, + E_cr=E_cr, + E_mr=E_mr, + E_sc=E_sc, + E_cov=E_cov, + latent_dim=args.latent_dim, + cofactor=args.cofactor, + use_covs=args.use_covs, + cohort=args.cohort, + use_weights=args.use_weights, + beta_scheme=args.beta_scheme, + linear_decoder=args.use_linear_decoder, + n_covariates=n_covariates, + batch_correct=args.batch_correct, + batch_size=args.batch_size, + random_seed=args.random_seed, + n_hidden=args.n_hidden, + leave_out_view=args.remove_view, + output_dir=args.output_dir, +) + + +print("Start training") + + +model.train() + +wandb.finish() + +model_checkpoint = [ + i for i in os.listdir(args.output_dir) if ".ckpt" in i +] # should only be 1 -- saved best model + +print("model_checkpoint", model_checkpoint) + +load_chkpt = torch.load(os.path.join(args.output_dir, model_checkpoint[0])) + +state_dict = load_chkpt["state_dict"] +# print(state_dict) +new_state_dict = OrderedDict() +for k, v in state_dict.items(): + # print("key", k) + if "weight" or "bias" in k: + # print("changing", k) + name = "module." 
+ k # add `module.` + # print("new name", name) + else: + # print("staying same", k) + name = k + new_state_dict[name] = v +# load params + +load_chkpt["state_dict"] = new_state_dict + +# torch.save(os.path.join(args.output_dir, model_checkpoint[0])) + +model = hmivaeModel( + adata=raw_adata, + input_exp_dim=input_exp_dim, + input_corr_dim=input_corr_dim, + input_morph_dim=input_morph_dim, + input_spcont_dim=input_spcont_dim, + E_me=E_me, + E_cr=E_cr, + E_mr=E_mr, + E_sc=E_sc, + E_cov=E_cov, + latent_dim=args.latent_dim, + cofactor=args.cofactor, + use_covs=args.use_covs, + use_weights=args.use_weights, + beta_scheme=args.beta_scheme, + linear_decoder=args.use_linear_decoder, + n_covariates=n_covariates, + batch_correct=args.batch_correct, + batch_size=args.batch_size, + random_seed=args.random_seed, + n_hidden=args.n_hidden, + leave_out_view=args.remove_view, + output_dir=args.output_dir, +) +model.load_state_dict(new_state_dict, strict=False) + + +# model.load_from_checkpoint(os.path.join(args.output_dir, model_checkpoint[0]), adata=raw_adata) + +print("Best model loaded from checkpoint") + +stop = time.time() + +log_file.write(f"All training done in {(stop-start)/60} minutes \n") + +starta = time.time() + +adata = model.get_latent_representation( # use the best model to get the latent representations + adata=raw_adata, + protein_correlations_obsm_key="correlations", + cell_morphology_obsm_key="morphology", + continuous_covariate_keys=keys, + is_trained_model=True, + batch_correct=args.batch_correct, +) + +print("Doing cluster and neighbourhood enrichment analysis") + +print("===> Clustering using integrated space") + +sc.pp.neighbors( + adata, n_neighbors=100, use_rep="VAE", key_added="vae" +) # 100 nearest neighbours, will be used in downstream tests -- keep with PG + + +sc.tl.leiden(adata, neighbors_key="vae") + +print("===> Clustering using specific views") + +print("Expression") + +sc.pp.neighbors( + adata, n_neighbors=100, use_rep="expression_embedding", 
sc.tl.leiden(
    adata,
    neighbors_key="expression",
    key_added="expression_leiden",
    random_state=args.random_seed,
    resolution=0.5,
)  # expression wasn't too bad

# Remaining single-view neighbour graphs + Leiden, driven by one config
# table: (print label, graph key, embedding, extra leiden kwargs).
# Resolutions were tuned by inspection: morphology pulled way down (LOTS
# of clusters); correlation left at the default (few clusters anyway).
for label, graph_key, rep, extra in [
    ("Correlation", "correlation", "correlation_embedding", {}),
    ("Morphology", "morphology", "morphology_embedding", {"resolution": 0.1}),
    (
        "Spatial context",
        "spatial_context",
        "spatial_context_embedding",
        {"resolution": 0.5},
    ),
]:
    print(label)
    sc.pp.neighbors(adata, n_neighbors=100, use_rep=rep, key_added=graph_key)
    sc.tl.leiden(
        adata,
        neighbors_key=graph_key,
        key_added=f"{graph_key}_leiden",
        random_state=args.random_seed,
        **extra,
    )

print("===> Creating UMAPs")

# One UMAP per neighbour graph; each result is snapshotted out of the
# shared "X_umap" slot before the next call overwrites it.
for label, graph_key, slot in [
    ("Integrated space", "vae", "X_umap_int"),
    ("Expression", "expression", "X_umap_exp"),
    ("Correlations", "correlation", "X_umap_corr"),
    ("Morphology", "morphology", "X_umap_morph"),
    ("Spatial context", "spatial_context", "X_umap_spct"),
]:
    print(label)
    sc.tl.umap(adata, neighbors_key=graph_key, random_state=args.random_seed)
    adata.obsm[slot] = adata.obsm["X_umap"].copy()

print("Neighbourhood enrichment analysis")

sq.gr.spatial_neighbors(adata)
sq.gr.nhood_enrichment(adata, cluster_key="leiden")

print("===> Create the neighbourhood features")

h5 = adata.copy()

# Recompute spatial neighbour connectivities -- they are lost when the new
# AnnData below is assembled.
sc.pp.neighbors(h5, use_rep="spatial", n_neighbors=10)

data = ScModeDataloader(h5)

spatial_context = data.C.numpy()

spatial_context_names = [
    "neighbour_" + i
    for i in list(h5.var_names)
    + h5.uns["names_correlations"].tolist()
    + h5.uns["names_morphology"].tolist()
]

print("===> Creating new adata and ranking all features")

clustering = [i for i in h5.obs.columns if "leiden" in i]

all_features = np.concatenate(
    [h5.X, h5.obsm["correlations"], h5.obsm["morphology"], spatial_context], axis=1
)

names = np.concatenate(
    [
        h5.var_names,
        h5.uns["names_correlations"],
        h5.uns["names_morphology"],
        spatial_context_names,
    ]
)

all_features_df = pd.DataFrame(all_features, columns=names)

new_adata = AnnData(
    X=all_features_df,
    obs=h5.copy().obs,
    obsm=h5.copy().obsm,
    obsp=h5.copy().obsp,
    uns=h5.copy().uns,
)

for cl in clustering:
    print(f"Ranking features for clustering: {cl}")
    sc.tl.rank_genes_groups(new_adata, groupby=cl, key_added=f"{cl}_rank_gene_groups")
dfs = []

for cl in clustering:
    # Collect the per-group rankings computed above into one long table,
    # tagged with which clustering they came from.
    ranked_df = sc.get.rank_genes_groups_df(
        new_adata, group=None, key=f"{cl}_rank_gene_groups"
    )
    ranked_df["clustering"] = [cl] * ranked_df.shape[0]
    dfs.append(ranked_df)

full_ranked_df = pd.concat(dfs)

# For each clustering, build a (feature x group) score matrix restricted to
# the union of every group's top-10 features.
dfs2 = {}

for cl in clustering:
    print(f"sorting ranked features for {cl}")
    features_df = full_ranked_df.copy().query("clustering==@cl")

    top_feature_names = []
    for group in features_df.group.unique():
        group_df = features_df.query("group==@group")
        # rank_genes_groups_df output is already sorted best-first
        top_feature_names.extend(group_df.names.tolist()[0:10])
    # set order is irrelevant here: rows are np.sort()-ed below
    top_features = list(set(top_feature_names))

    new_df = pd.DataFrame({})
    for group in features_df.group.unique():
        group_df = features_df.query("group==@group")
        # .sort_index() aligns the score order with np.sort(top_features)
        scores = (
            group_df.loc[group_df.names.isin(top_features), ["names", "scores"]]
            .set_index("names")
            .sort_index()
            .scores.tolist()
        )
        new_df[group] = scores

    new_df.index = np.sort(top_features)

    # Re-order rows into the canonical E/C/M/S feature arrangement.
    arr_features2, arr_list2 = arrange_features(new_df.index.to_list(), adata)
    new_df = new_df.reindex(arr_list2)

    # NOTE(review): assumes group labels are integer-like strings (Leiden
    # cluster ids); non-numeric labels would raise here -- confirm.
    new_df.columns = new_df.columns.map(int)
    new_df = new_df[np.sort(new_df.columns)]

    dfs2[cl] = new_df

cmap = sns.diverging_palette(220, 20, as_cmap=True)

for cl in clustering:
    sns.clustermap(
        dfs2[cl].fillna(0),
        row_cluster=False,
        center=0.00,
        cmap=cmap,
        vmin=-100,
        vmax=100,
        figsize=(25, 25),
        linewidth=2,
        linecolor="black",
    )
    # Consistency fix: every other artefact is written to args.output_dir;
    # these figures previously landed in the current working directory.
    plt.savefig(
        os.path.join(args.output_dir, f"{args.cohort}_cluster_rankings_for_{cl}.png")
    )

print("old", new_adata.uns.keys())

# h5ad writing requires string keys in .uns.
new_uns = {str(k): v for k, v in new_adata.uns.items()}

print("new", new_uns.keys())

# BUG FIX: the sanitized keys were assigned only to `adata.uns`, but the
# object written below in the main branch is `new_adata`. Sanitize both so
# whichever object gets written has string keys.
new_adata.uns = new_uns
adata.uns = new_uns

if args.include_all_views:
    new_adata.obs.to_csv(
        os.path.join(args.output_dir, f"{args.cohort}_clusters.tsv"), sep="\t"
    )
    new_adata.write(os.path.join(args.output_dir, f"{args.cohort}_adata_new.h5ad"))
    full_ranked_df.to_csv(
        os.path.join(args.output_dir, f"{args.cohort}_clusters_ranked_features.tsv"),
        sep="\t",
    )
else:
    adata.obs.to_csv(
        os.path.join(
            args.output_dir, f"{args.cohort}_remove_{args.remove_view}_clusters.tsv"
        ),
        sep="\t",
    )
    adata.write(
        os.path.join(
            args.output_dir, f"{args.cohort}_adata_remove_{args.remove_view}.h5ad"
        )
    )
- **model_kwargs - Keyword args for :class:`~mypackage.MyModule` - Examples - -------- - >>> adata = anndata.read_h5ad(path_to_anndata) - >>> mypackage.MyModel.setup_anndata(adata, batch_key="batch") - >>> vae = mypackage.MyModel(adata) - >>> vae.train() - >>> adata.obsm["X_mymodel"] = vae.get_latent_representation() - """ - - def __init__( - self, - adata: AnnData, - n_hidden: int = 128, - n_latent: int = 10, - n_layers: int = 1, - **model_kwargs, - ): - super(MyModel, self).__init__(adata) - - library_log_means, library_log_vars = _init_library_size( - adata, self.summary_stats["n_batch"] - ) - - # self.summary_stats provides information about anndata dimensions and other tensor info - - self.module = MyModule( - n_input=self.summary_stats["n_vars"], - n_hidden=n_hidden, - n_latent=n_latent, - n_layers=n_layers, - library_log_means=library_log_means, - library_log_vars=library_log_vars, - **model_kwargs, - ) - self._model_summary_string = "Overwrite this attribute to get an informative representation for your model" - # necessary line to get params that will be used for saving/loading - self.init_params_ = self._get_init_params(locals()) - - logger.info("The model has been initialized") - - @staticmethod - @setup_anndata_dsp.dedent - def setup_anndata( - adata: AnnData, - batch_key: Optional[str] = None, - labels_key: Optional[str] = None, - layer: Optional[str] = None, - categorical_covariate_keys: Optional[List[str]] = None, - continuous_covariate_keys: Optional[List[str]] = None, - copy: bool = False, - ) -> Optional[AnnData]: - """ - %(summary)s. 
- Parameters - ---------- - %(param_adata)s - %(param_batch_key)s - %(param_labels_key)s - %(param_layer)s - %(param_cat_cov_keys)s - %(param_cont_cov_keys)s - %(param_copy)s - - Returns - ------- - %(returns)s - """ - return setup_anndata( - adata, - batch_key=batch_key, - labels_key=labels_key, - layer=layer, - categorical_covariate_keys=categorical_covariate_keys, - continuous_covariate_keys=continuous_covariate_keys, - copy=copy, - ) diff --git a/mypackage/_mymodule.py b/mypackage/_mymodule.py deleted file mode 100644 index ed41b30..0000000 --- a/mypackage/_mymodule.py +++ /dev/null @@ -1,293 +0,0 @@ -import numpy as np -import torch -import torch.nn.functional as F -from scvi import _CONSTANTS -from scvi.distributions import ZeroInflatedNegativeBinomial -from scvi.module.base import BaseModuleClass, LossRecorder, auto_move_data -from scvi.nn import DecoderSCVI, Encoder, one_hot -from torch.distributions import Normal -from torch.distributions import kl_divergence as kl - -torch.backends.cudnn.benchmark = True - - -class MyModule(BaseModuleClass): - """ - Skeleton Variational auto-encoder model. - - Here we implement a basic version of scVI's underlying VAE [Lopez18]_. - This implementation is for instructional purposes only. - - Parameters - ---------- - n_input - Number of input genes - library_log_means - 1 x n_batch array of means of the log library sizes. Parameterizes prior on library size if - not using observed library size. - library_log_vars - 1 x n_batch array of variances of the log library sizes. Parameterizes prior on library size if - not using observed library size. - n_batch - Number of batches, if 0, no batch correction is performed. 
- n_hidden - Number of nodes per hidden layer - n_latent - Dimensionality of the latent space - n_layers - Number of hidden layers used for encoder and decoder NNs - dropout_rate - Dropout rate for neural networks - """ - - def __init__( - self, - n_input: int, - library_log_means: np.ndarray, - library_log_vars: np.ndarray, - n_batch: int = 0, - n_hidden: int = 128, - n_latent: int = 10, - n_layers: int = 1, - dropout_rate: float = 0.1, - ): - super().__init__() - self.n_latent = n_latent - self.n_batch = n_batch - # this is needed to comply with some requirement of the VAEMixin class - self.latent_distribution = "normal" - - self.register_buffer( - "library_log_means", torch.from_numpy(library_log_means).float() - ) - self.register_buffer( - "library_log_vars", torch.from_numpy(library_log_vars).float() - ) - - # setup the parameters of your generative model, as well as your inference model - self.px_r = torch.nn.Parameter(torch.randn(n_input)) - # z encoder goes from the n_input-dimensional data to an n_latent-d - # latent space representation - self.z_encoder = Encoder( - n_input, - n_latent, - n_layers=n_layers, - n_hidden=n_hidden, - dropout_rate=dropout_rate, - ) - # l encoder goes from n_input-dimensional data to 1-d library size - self.l_encoder = Encoder( - n_input, - 1, - n_layers=1, - n_hidden=n_hidden, - dropout_rate=dropout_rate, - ) - # decoder goes from n_latent-dimensional space to n_input-d data - self.decoder = DecoderSCVI( - n_latent, - n_input, - n_layers=n_layers, - n_hidden=n_hidden, - ) - - def _get_inference_input(self, tensors): - """Parse the dictionary to get appropriate args""" - x = tensors[_CONSTANTS.X_KEY] - - input_dict = dict(x=x) - return input_dict - - def _get_generative_input(self, tensors, inference_outputs): - z = inference_outputs["z"] - library = inference_outputs["library"] - - input_dict = { - "z": z, - "library": library, - } - return input_dict - - @auto_move_data - def inference(self, x): - """ - High level inference 
method. - - Runs the inference (encoder) model. - """ - # log the input to the variational distribution for numerical stability - x_ = torch.log(1 + x) - # get variational parameters via the encoder networks - qz_m, qz_v, z = self.z_encoder(x_) - ql_m, ql_v, library = self.l_encoder(x_) - - outputs = dict(z=z, qz_m=qz_m, qz_v=qz_v, ql_m=ql_m, ql_v=ql_v, library=library) - return outputs - - @auto_move_data - def generative(self, z, library): - """Runs the generative model.""" - - # form the parameters of the ZINB likelihood - px_scale, _, px_rate, px_dropout = self.decoder("gene", z, library) - px_r = torch.exp(self.px_r) - - return dict( - px_scale=px_scale, px_r=px_r, px_rate=px_rate, px_dropout=px_dropout - ) - - def loss( - self, - tensors, - inference_outputs, - generative_outputs, - kl_weight: float = 1.0, - ): - x = tensors[_CONSTANTS.X_KEY] - qz_m = inference_outputs["qz_m"] - qz_v = inference_outputs["qz_v"] - ql_m = inference_outputs["ql_m"] - ql_v = inference_outputs["ql_v"] - px_rate = generative_outputs["px_rate"] - px_r = generative_outputs["px_r"] - px_dropout = generative_outputs["px_dropout"] - - mean = torch.zeros_like(qz_m) - scale = torch.ones_like(qz_v) - - kl_divergence_z = kl(Normal(qz_m, torch.sqrt(qz_v)), Normal(mean, scale)).sum( - dim=1 - ) - - batch_index = tensors[_CONSTANTS.BATCH_KEY] - n_batch = self.library_log_means.shape[1] - local_library_log_means = F.linear( - one_hot(batch_index, n_batch), self.library_log_means - ) - local_library_log_vars = F.linear( - one_hot(batch_index, n_batch), self.library_log_vars - ) - - kl_divergence_l = kl( - Normal(ql_m, torch.sqrt(ql_v)), - Normal(local_library_log_means, torch.sqrt(local_library_log_vars)), - ).sum(dim=1) - - reconst_loss = ( - -ZeroInflatedNegativeBinomial(mu=px_rate, theta=px_r, zi_logits=px_dropout) - .log_prob(x) - .sum(dim=-1) - ) - - kl_local_for_warmup = kl_divergence_z - kl_local_no_warmup = kl_divergence_l - - weighted_kl_local = kl_weight * kl_local_for_warmup + 
kl_local_no_warmup - - loss = torch.mean(reconst_loss + weighted_kl_local) - - kl_local = dict( - kl_divergence_l=kl_divergence_l, kl_divergence_z=kl_divergence_z - ) - kl_global = torch.tensor(0.0) - return LossRecorder(loss, reconst_loss, kl_local, kl_global) - - @torch.no_grad() - def sample( - self, - tensors, - n_samples=1, - library_size=1, - ) -> np.ndarray: - r""" - Generate observation samples from the posterior predictive distribution. - - The posterior predictive distribution is written as :math:`p(\hat{x} \mid x)`. - - Parameters - ---------- - tensors - Tensors dict - n_samples - Number of required samples for each cell - library_size - Library size to scale scamples to - - Returns - ------- - x_new : :py:class:`torch.Tensor` - tensor with shape (n_cells, n_genes, n_samples) - """ - inference_kwargs = dict(n_samples=n_samples) - _, generative_outputs, = self.forward( - tensors, - inference_kwargs=inference_kwargs, - compute_loss=False, - ) - - px_r = generative_outputs["px_r"] - px_rate = generative_outputs["px_rate"] - px_dropout = generative_outputs["px_dropout"] - - dist = ZeroInflatedNegativeBinomial( - mu=px_rate, theta=px_r, zi_logits=px_dropout - ) - - if n_samples > 1: - exprs = dist.sample().permute( - [1, 2, 0] - ) # Shape : (n_cells_batch, n_genes, n_samples) - else: - exprs = dist.sample() - - return exprs.cpu() - - @torch.no_grad() - @auto_move_data - def marginal_ll(self, tensors, n_mc_samples): - sample_batch = tensors[_CONSTANTS.X_KEY] - batch_index = tensors[_CONSTANTS.BATCH_KEY] - - to_sum = torch.zeros(sample_batch.size()[0], n_mc_samples) - - for i in range(n_mc_samples): - # Distribution parameters and sampled variables - inference_outputs, _, losses = self.forward(tensors) - qz_m = inference_outputs["qz_m"] - qz_v = inference_outputs["qz_v"] - z = inference_outputs["z"] - ql_m = inference_outputs["ql_m"] - ql_v = inference_outputs["ql_v"] - library = inference_outputs["library"] - - # Reconstruction Loss - reconst_loss = 
class HMIDataset(TensorDataset):
    """Dataset over a directory of per-image .h5ad files.

    Each item is the AnnData for one image, read lazily from disk on
    `__getitem__`.
    """

    def __init__(
        self,
        h5ad_dir,
    ):
        """
        Input is a directory with all h5ad files.
        h5ad_dir: Directory containing all h5ad files for each image in HMI dataset
        """
        self.h5ad_dir = h5ad_dir
        # Sort for a deterministic index -> file mapping: os.listdir order
        # is arbitrary and filesystem-dependent.
        self.h5ad_names = pd.DataFrame(
            {"Sample_names": sorted(os.listdir(h5ad_dir))}
        )

    def __len__(self):
        # BUG FIX: use the snapshot taken at construction time. The
        # original re-ran os.listdir here, so the reported length could
        # disagree with __getitem__ if the directory changed after init.
        return len(self.h5ad_names)

    def __getitem__(self, idx):
        # Read the idx-th file (by sorted name) lazily from disk.
        h5ad_path = os.path.join(self.h5ad_dir, self.h5ad_names.iloc[idx, 0])
        h5ad = sc.read_h5ad(h5ad_path)

        return h5ad
+# parser.add_argument( +# "--train_ratio", +# type=float, +# help="Ratio of the full dataset to be treated as the training set", +# default=0.75, +# ) + +# parser.add_argument("--subset_to", type=int, help="Data subset size") + +# parser.add_argument( +# "--winsorize", type=int, help="0 or 1 to denote False or True", default=1 +# ) + +# parser.add_argument("--cofactor", type=float, help="Value for cofactor", default=5.0) + +# # parser.add_argument( +# # "--n_proteins", type=int, required=True, help="Number of proteins in the dataset" +# # ) + +# parser.add_argument( +# "--use_weights", type=int, help="0 or 1 to denote False or True", default=0 +# ) + +# parser.add_argument( +# "--apply_arctanh", type=int, help="0 or 1 to denote False or True", default=0 +# ) + +# parser.add_argument("--cohort", type=str, help="Name of cohort", default="None") + +# parser.add_argument("--beta", type=float, help="beta value for B-VAE", default=1.0) + +# parser.add_argument("--n_epochs", type=int, help="number of epochs", default=200) + +# parser.add_argument( +# "--apply_KLwarmup", +# type=int, +# help="0 or 1 as False or True, to apply a KL-warmup scheme, if not, then BETA is used as given", +# default=1, +# ) + +# parser.add_argument( +# "--regress_out_patient", +# type=int, +# help="0 or 1 as False or True, to regress out patient effects, default is False", +# default=0, +# ) + +# parser.add_argument( +# "--KL_limit", +# type=float, +# help="Max limit for the coefficient of the KL-Div term", +# default=0.3, +# ) + +# parser.add_argument( +# "--output_dir", type=str, help="Directory to store the outputs", default="." 
+# ) + +# args = parser.parse_args() + + +# adata = sc.read_h5ad(args.input_h5ad) + +# COFACTOR = args.cofactor + +# RANDOM_SEED = args.random_seed + +# N_EPOCHS = args.n_epochs +# N_HIDDEN = 2 +# HIDDEN_LAYER_SIZE_Eme = 8 +# HIDDEN_LAYER_SIZE_Ecr = 8 +# HIDDEN_LAYER_SIZE_Emr = 8 +# N_SPATIAL_CONTEXT = ( +# HIDDEN_LAYER_SIZE_Eme + HIDDEN_LAYER_SIZE_Ecr + HIDDEN_LAYER_SIZE_Emr +# ) +# HIDDEN_LAYER_SIZE_Esc = 8 # keeping this consistent with the previous Basel analysis + +# LATENT_DIM = 10 +# BATCH_SIZE = 256 +# CELLS_CUTOFF = 500 + +# N_TOTAL_CELLS = adata.shape[0] +# N_PROTEINS = adata.shape[1] +# N_CORRELATIONS = len(adata.uns["names_correlations"]) +# N_MORPHOLOGY = len(adata.uns["names_morphology"]) + +# N_TOTAL_FEATURES = N_PROTEINS + N_CORRELATIONS + N_MORPHOLOGY + +# BETA = args.beta ## beta for beta-vae + +# TRAIN_PROP = args.train_ratio # set the training set ratio + +# lr = args.lr # set the learning rate + +# # log_py = {} +# # elbo_losses = {} + + +# if args.output_dir is not None: +# output_dir = args.output_dir +# if not os.path.exists(output_dir): +# os.makedirs(output_dir) +# else: +# output_dir = "." 
+ + +# # Set up the data +# np.random.seed(RANDOM_SEED) + +# if args.subset_to is not None: +# print("Subsetting samples") +# samples = adata.obs.Sample_name.unique().to_list() +# inds = np.random.choice(samples, args.subset_to) +# adata = adata[adata.obs.Sample_name.isin(inds)] +# else: +# adata = adata + + +# # adata.obs = adata.obs.reset_index() +# # adata.obs.columns = ["index", "Sample_name", "cell_id"] + + +# if adata.X.shape[0] > 705000: +# sample_drop_lst = [] +# for sample in adata.obs[ +# "Sample_name" +# ].unique(): # if sample has less than 500 cells, drop it +# if ( +# adata.obs.query("Sample_name==@sample").shape[0] < CELLS_CUTOFF +# ): # true for <235 samples out of all samples +# sample_drop_lst.append(sample) + +# adata_sub = adata.copy()[ +# ~adata.obs.Sample_name.isin(sample_drop_lst), : +# ] # select all rows except those that belong to samples w cells < CELLS_CUTOFF + +# adata_sub.obs = adata_sub.obs.reset_index() +# if "level_0" in adata_sub.obs.columns: +# adata_sub.obs = adata_sub.obs.drop(columns=["level_0"]) + +# else: +# adata_sub = adata + + +# print("Preprocessing data views") + +# if args.cofactor is not None: +# adata_sub.X = np.arcsinh(adata_sub.X / COFACTOR) + + +# if args.winsorize == 1: +# for i in range(N_PROTEINS): +# adata_sub.X[:, i] = winsorize(adata_sub.X[:, i], limits=[0, 0.01]) +# for i in range(N_MORPHOLOGY): +# adata_sub.obsm["morphology"][:, i] = winsorize( +# adata_sub.obsm["morphology"][:, i], limits=[0, 0.01] +# ) + +# if args.apply_arctanh == 1: +# adata_sub.obsm["correlations"] = np.arctanh(adata_sub.obsm["correlations"]) + + +# adata_sub.obs["Sample_name"] = adata_sub.obs["Sample_name"].astype( +# str +# ) # have to do this otherwise it will contain the ones that were removed + +# if args.regress_out_patient: +# print("Regressing out patient effect") +# sc.pp.regress_out(adata_sub, "Sample_name") + +# samples_list = adata_sub.obs["Sample_name"].unique().tolist() # samples in the adata + + +# device = 
torch.device("cuda" if torch.cuda.is_available() else "cpu") + + +# train_size = int(np.floor(len(samples_list) * TRAIN_PROP)) +# test_size = len(samples_list) - train_size + +# # separate images/samples as train or test *only* (this is different from before, when we separated cells into train/test) + +# print("Setting up train and test data") + +# samples_train, samples_test = train_test_split( +# samples_list, train_size=TRAIN_PROP, random_state=RANDOM_SEED +# ) + +# adata_train = adata_sub.copy()[adata_sub.obs["Sample_name"].isin(samples_train), :] +# adata_test = adata_sub.copy()[adata_sub.obs["Sample_name"].isin(samples_test), :] + + +# data_train = ScModeDataloader(adata_train) +# data_test = ScModeDataloader(adata_test, data_train.scalers) + +# loader_train = DataLoader(data_train, batch_size=BATCH_SIZE, shuffle=True) +# loader_test = DataLoader(data_test, batch_size=BATCH_SIZE, shuffle=True) + + +# model = hmiVAE( +# N_PROTEINS, +# N_CORRELATIONS, +# N_MORPHOLOGY, +# N_SPATIAL_CONTEXT, +# HIDDEN_LAYER_SIZE_Eme, +# HIDDEN_LAYER_SIZE_Ecr, +# HIDDEN_LAYER_SIZE_Emr, +# HIDDEN_LAYER_SIZE_Esc, +# LATENT_DIM, +# ) + +# trainer = Trainer() + +# wandb.init( +# project="vae_new_morphs", +# entity="sayub", +# config={ +# "learning_rate": lr, +# "epochs": N_EPOCHS, +# "batch_size": BATCH_SIZE, +# "use_weights": int(corr_weights), +# "use_arctanh": int(args.apply_arctanh), +# "n_cells": len(data_train), +# "method": "emVAE", +# "cohort": args.cohort, +# "BETA": BETA, +# "cofactor": COFACTOR, +# "regress_out_patient": args.regress_out_patient, +# "apply_KL_warmup": args.apply_KLwarmup, +# "KL_max_limit": args.KL_limit, +# }, +# ) + + +# # start_time_em = time.time() + + +# # ## Reconstruction weights +# # r = np.array([1., 1., 1., 0.]) + +# # for n in range(N_EPOCHS): +# # if args.apply_KLwarmup: +# # if n>5: +# # new_beta = BETA + 0.05 +# # BETA = min(new_beta, args.KL_limit) +# # else: +# # BETA = BETA + +# # if n > 5: +# # spcont_r = r[3]+0.1 +# # r[3] = min(spcont_r, 
1.0) +# # if n < 30: +# # print(r) +# # #r = np.array([1., 1., 1., 1.]) + +# # optimizer_em.zero_grad() + +# # train_losses_em = 0 +# # num_batches_em = 0 + +# for num, batch in enumerate(loader_train): +# data = { +# "Y": batch[0], +# "S": batch[1], +# "M": batch[2], +# "A": spatial_context[batch[-1], :], +# } + +# em_train_loss = train_test_run( +# data=data, +# spatial_context=data["A"], +# method="EM", +# method_enc=enc_em, +# method_dec=dec_em, +# n_proteins=N_PROTEINS, +# latent_dim=LATENT_DIM, +# corr_weights=corr_weights, +# recon_weights=r, +# beta=BETA, +# ) + +# ## Update gradients and weights +# em_train_loss[0].backward() + +# torch.nn.utils.clip_grad_norm_(parameters, 2.0) + +# optimizer_em.step() + +# train_losses_em += em_train_loss[0].detach().item() +# num_batches_em += 1 + + +# # Update old mean embeddings (once per epoch) +# with torch.no_grad(): +# mu_z_old, _, z1 = enc_em(data_train.Y, data_train.S, data_train.M, z1) # mu_z_old) + +# wandb.log( +# { +# "train_neg_elbo": train_losses_em / num_batches_em, +# "kl_div": em_train_loss[1], +# "recon_lik": em_train_loss[2], +# "recon_lik_me": em_train_loss[3], +# "recon_lik_corr": em_train_loss[4], +# "recon_lik_mor": em_train_loss[5], +# "recon_lik_spcont": em_train_loss[6], +# "mu_z_max": em_train_loss[8], +# "log_std_z_max": em_train_loss[9], +# "mu_z_min": em_train_loss[10], +# "log_std_z_min": em_train_loss[11], +# "mu_x_exp_hat_max": em_train_loss[12], +# "log_std_x_exp_hat_max": em_train_loss[13], +# "mu_x_exp_hat_min": em_train_loss[14], +# "log_std_x_exp_hat_min": em_train_loss[15], +# "mu_x_corr_hat_max": em_train_loss[16], +# "log_std_x_corr_hat_max": em_train_loss[17], +# "mu_x_corr_hat_min": em_train_loss[18], +# "log_std_x_corr_hat_min": em_train_loss[19], +# "mu_x_morph_hat_max": em_train_loss[20], +# "log_std_x_morph_hat_max": em_train_loss[21], +# "mu_x_morph_hat_min": em_train_loss[22], +# "log_std_x_morph_hat_min": em_train_loss[23], +# "mu_x_spcont_hat_max": em_train_loss[24], +# 
"log_std_x_spcont_hat_max": em_train_loss[25], +# "mu_x_spcont_hat_min": em_train_loss[26], +# "log_std_x_spcont_hat_min": em_train_loss[27], +# } +# ) + +# # Now compute test metrics + +# with torch.no_grad(): +# spatial_context_test = torch.smm( +# adj_mat_test_tensor, z1_test # mu_z_old_test +# ).to_dense() + +# test_data = { +# "Y": data_test.Y, +# "S": data_test.S, +# "M": data_test.M, +# "A": spatial_context_test, +# } + +# em_test_loss = train_test_run( +# data=test_data, +# spatial_context=test_data["A"], +# method="EM", +# method_enc=enc_em, +# method_dec=dec_em, +# n_proteins=N_PROTEINS, +# latent_dim=LATENT_DIM, +# corr_weights=corr_weights, +# recon_weights=r, +# beta=BETA, +# ) + +# # mu_z_old_test = em_test_loss[7] +# z1_test = em_test_loss[7] + +# wandb.log( +# { +# "test_neg_elbo": em_test_loss[0], +# "test_kl_div": em_test_loss[1], +# "test_recon_lik": em_test_loss[2], +# "test_recon_lik_me": em_test_loss[3], +# "test_recon_lik_corr": em_test_loss[4], +# "test_recon_lik_mor": em_test_loss[5], +# "test_recon_lik_spcont": em_test_loss[6], +# } +# ) + +# stop_time_em = time.time() +# wandb.finish() +# print("emVAE done training") + +# # adata_train.obsm['emVAE_final_spatial_context'] = mu_z_old_train_em.detach().numpy() +# # adata_test.obsm['emVAE_final_spatial_context'] = mu_z_old_test_em.detach().numpy() + +# # eval_em_spatial_context= torch.smm( +# # adj_mat_test_tensor, mu_z_old_test_em +# # ).to_dense() + +# with torch.no_grad(): +# mu_z, _, z1 = enc_em(data_train.Y, data_train.S, data_train.M, z1) +# mu_z_test, _, z1_test = enc_em(data_test.Y, data_test.S, data_test.M, z1_test) + +# adata_train.obsm["VAE"] = mu_z.numpy() +# adata_test.obsm["VAE"] = mu_z_test.numpy() + +# adata_train.obsm["spatial_context"] = z1.numpy() +# adata_test.obsm["spatial_context"] = z1_test.numpy() + +# adata_train.write(os.path.join(output_dir, "adata_train.h5ad")) +# adata_test.write(os.path.join(output_dir, "adata_test.h5ad")) + +# torch.save(enc_em.state_dict(), 
os.path.join(output_dir, "emVAE_encoder.pt")) +# torch.save(dec_em.state_dict(), os.path.join(output_dir, "emVAE_decoder.pt")) + + +# print("All done!")