Hmi vae init #1

Open
wants to merge 18 commits into main
Changes from 1 commit
one-hot encoding added to train and test v1
shanzaayub committed Jun 20, 2022
commit 1711e6f23284351aaf07adc2e5aa3e82816a12f2
2 changes: 1 addition & 1 deletion hmivae/ScModeDataloader.py
@@ -70,7 +70,7 @@ def one_hot_encoding(self, test=False):

df = df.reindex(columns=self.adata.obs.Sample_name.unique().tolist())

return torch.tensor(df.to_numpy())
return torch.tensor(df.to_numpy()).float()

def get_spatial_context(self):
adj_mat = sparse_numpy_to_torch(
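
Note on the .float() cast above: the per-sample one-hot encoding is later concatenated with float-typed inputs, so returning float32 avoids a dtype mismatch in torch.cat. Below is a minimal sketch of the same pattern, assuming the frame is built with pandas.get_dummies (the construction of df is not shown in this hunk):

# Illustrative sketch, not the repo's code.
import pandas as pd
import torch

sample_names = ["s1", "s1", "s2", "s3"]        # hypothetical Sample_name values per cell
all_samples = ["s1", "s2", "s3"]               # column order, one column per sample

df = pd.get_dummies(pd.Series(sample_names))
df = df.reindex(columns=all_samples, fill_value=0)

one_hot = torch.tensor(df.to_numpy()).float()  # float32, safe to torch.cat with float inputs
print(one_hot.dtype, one_hot.shape)            # torch.float32 torch.Size([4, 3])
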
95 changes: 66 additions & 29 deletions hmivae/_hmivae_base_components.py
@@ -1,3 +1,5 @@
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
@@ -19,19 +21,22 @@ class EncoderHMIVAE(nn.Module):

def __init__(
self,
input_exp_dim,
input_corr_dim,
input_morph_dim,
input_spcont_dim,
E_me,
E_cr,
E_mr,
E_sc,
latent_dim,
n_hidden=1,
input_exp_dim: int,
input_corr_dim: int,
input_morph_dim: int,
input_spcont_dim: int,
E_me: int,
E_cr: int,
E_mr: int,
E_sc: int,
latent_dim: int,
n_covariates: Optional[int] = 0,
n_hidden: Optional[int] = 1,
):
super().__init__()
hidden_dim = E_me + E_cr + E_mr + E_sc
hidden_dim = E_me + E_cr + E_mr + E_sc + n_covariates

self.input_cov = nn.Linear(n_covariates, n_covariates)

self.input_exp = nn.Linear(input_exp_dim, E_me)
self.exp_hidden = nn.Linear(E_me, E_me)
@@ -50,7 +55,14 @@ def __init__(
self.mu_z = nn.Linear(hidden_dim, latent_dim)
self.std_z = nn.Linear(hidden_dim, latent_dim)

def forward(self, x_mean, x_correlations, x_morphology, x_spatial_context):
def forward(
self,
x_mean: torch.Tensor,
x_correlations: torch.Tensor,
x_morphology: torch.Tensor,
x_spatial_context: torch.Tensor,
cov_list=torch.Tensor([]),
):
h_mean = F.elu(self.input_exp(x_mean))
h_mean2 = F.elu(self.exp_hidden(h_mean))

@@ -60,12 +72,14 @@ def forward(self, x_mean, x_correlations, x_morphology, x_spatial_context):
h_morphology = F.elu(self.input_morph(x_morphology))
h_morphology2 = F.elu(self.morph_hidden(h_morphology))

# z1 = torch.cat([h_mean2, h_correlations2, h_morphology2], 1)

h_spatial_context = F.elu(self.input_spatial_context(x_spatial_context))
h_spatial_context2 = F.elu(self.spatial_context_hidden(h_spatial_context))

h = torch.cat([h_mean2, h_correlations2, h_morphology2, h_spatial_context2], 1)
h_cov = F.elu(self.input_cov(cov_list))

h = torch.cat(
[h_mean2, h_correlations2, h_morphology2, h_spatial_context2, h_cov], 1
)

for net in self.linear:
h = F.elu(net(h))
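
The encoder now embeds the covariates through input_cov and widens the concatenated hidden vector by n_covariates, matching the hidden_dim used for mu_z and std_z. A shape sketch with illustrative sizes (values are hypothetical, not taken from the PR):

import torch

E_me, E_cr, E_mr, E_sc, n_covariates = 32, 32, 32, 32, 3
hidden_dim = E_me + E_cr + E_mr + E_sc + n_covariates       # 131

batch = 8
h_mean2  = torch.randn(batch, E_me)
h_corr2  = torch.randn(batch, E_cr)
h_morph2 = torch.randn(batch, E_mr)
h_sp2    = torch.randn(batch, E_sc)
h_cov    = torch.randn(batch, n_covariates)                 # stands in for F.elu(input_cov(cov_list))

h = torch.cat([h_mean2, h_corr2, h_morph2, h_sp2, h_cov], 1)
assert h.shape[1] == hidden_dim                             # width expected by mu_z / std_z
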
@@ -94,19 +108,21 @@ class DecoderHMIVAE(nn.Module):

def __init__(
self,
latent_dim,
E_me,
E_cr,
E_mr,
E_sc,
input_exp_dim,
input_corr_dim,
input_morph_dim,
input_spcont_dim,
n_hidden=1,
latent_dim: int,
E_me: int,
E_cr: int,
E_mr: int,
E_sc: int,
input_exp_dim: int,
input_corr_dim: int,
input_morph_dim: int,
input_spcont_dim: int,
n_covariates: Optional[int] = 0,
n_hidden: Optional[int] = 1,
):
super().__init__()
hidden_dim = E_me + E_cr + E_mr + E_sc
hidden_dim = E_me + E_cr + E_mr + E_sc + n_covariates
latent_dim = latent_dim + n_covariates
self.E_me = E_me
self.E_cr = E_cr
self.E_mr = E_mr
@@ -135,8 +151,12 @@ def __init__(
self.mu_x_spcont = nn.Linear(E_sc, input_spcont_dim)
self.std_x_spcont = nn.Linear(E_sc, input_spcont_dim)

def forward(self, z):
out = F.elu(self.input(z))
self.covariates_out_mu = nn.Linear(n_covariates, n_covariates)
self.covariates_out_std = nn.Linear(n_covariates, n_covariates)

def forward(self, z, cov_list):
z_s = torch.cat([z, cov_list], 1)
out = F.elu(self.input(z_s))
for net in self.linear:
out = F.elu(net(out))

@@ -150,9 +170,21 @@ def forward(self, z):
)
)
h2_spatial_context = F.elu(
self.spatial_context_hidden(out[:, self.E_me + self.E_cr + self.E_mr :])
self.spatial_context_hidden(
out[
:,
self.E_me
+ self.E_cr
+ self.E_mr : self.E_me
+ self.E_cr
+ self.E_mr
+ self.E_sc,
]
)
)

covariates = out[:, self.E_me + self.E_cr + self.E_mr + self.E_sc :]

mu_x_exp = self.mu_x_exp(h2_mean)
std_x_exp = self.std_x_exp(h2_mean)

@@ -173,6 +205,9 @@ def forward(self, z):
mu_x_spatial_context = self.mu_x_spcont(h2_spatial_context)
std_x_spatial_context = self.std_x_spcont(h2_spatial_context)

covariates_mu = self.covariates_out_mu(covariates)
covariates_std = self.covariates_out_std(covariates)

return (
mu_x_exp,
std_x_exp,
@@ -182,5 +217,7 @@ def forward(self, z):
std_x_morph,
mu_x_spatial_context,
std_x_spatial_context,
covariates_mu,
covariates_std,
# weights,
)
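
On the decoder side, z is concatenated with cov_list before the first linear layer, which is why the constructor grows latent_dim by n_covariates; the trailing n_covariates columns of the hidden activation then feed covariates_out_mu / covariates_out_std. A shape sketch with illustrative sizes (not the repo's code):

import torch
import torch.nn as nn

latent_dim, n_covariates = 10, 3
E_me = E_cr = E_mr = E_sc = 32
hidden_dim = E_me + E_cr + E_mr + E_sc + n_covariates

input_layer = nn.Linear(latent_dim + n_covariates, hidden_dim)

z = torch.randn(8, latent_dim)
cov_list = torch.randn(8, n_covariates)
out = input_layer(torch.cat([z, cov_list], 1))

covariates = out[:, E_me + E_cr + E_mr + E_sc :]   # trailing block used for the covariate heads
print(out.shape, covariates.shape)                 # torch.Size([8, 131]) torch.Size([8, 3])
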
62 changes: 50 additions & 12 deletions hmivae/_hmivae_model.py
@@ -3,8 +3,10 @@

import numpy as np
import pytorch_lightning as pl
import torch
from _hmivae_module import hmiVAE
from anndata import AnnData
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning.trainer import Trainer
from scipy.stats.mstats import winsorize
from ScModeDataloader import ScModeDataloader
@@ -60,17 +62,14 @@ def __init__(
E_mr: int = 32,
E_sc: int = 32,
latent_dim: int = 10,
n_covariates: int = 0,
n_hidden: int = 1,
**model_kwargs,
):
# super(hmivaeModel, self).__init__(adata)
super().__init__()

# library_log_means, library_log_vars = _init_library_size(
# adata, self.summary_stats["n_batch"]
# )

self.train_batch, self.test_batch = self.setup_anndata(
self.train_batch, self.test_batch, self.n_covariates = self.setup_anndata(
adata=adata,
protein_correlations_obsm_key="correlations",
cell_morphology_obsm_key="morphology",
@@ -87,18 +86,21 @@ def __init__(
E_mr=E_mr,
E_sc=E_sc,
latent_dim=latent_dim,
n_covariates=self.n_covariates,
n_hidden=n_hidden,
**model_kwargs,
)
self._model_summary_string = (
"hmiVAE model with the following parameters: \nn_latent:{}"
"n_protein_expression:{}, n_correlation:{}, n_morphology:{}, n_spatial_context:{}"
"hmiVAE model with the following parameters: \n n_latent:{}, "
"n_protein_expression:{}, n_correlation:{}, n_morphology:{}, n_spatial_context:{}, "
"n_covariates:{} "
).format(
latent_dim,
input_exp_dim,
input_corr_dim,
input_morph_dim,
input_spcont_dim,
n_covariates,
)
# necessary line to get params that will be used for saving/loading
# self.init_params_ = self._get_init_params(locals())
@@ -109,13 +111,45 @@ def train(
self,
): # misnomer, both train and test are here (either rename or separate)

trainer = Trainer(max_epochs=10)
early_stopping = EarlyStopping(
monitor="test_loss", mode="min", patience=3
) # need to add this

trainer = Trainer(max_epochs=10, callbacks=[early_stopping])

trainer.fit(self.module, self.train_batch) # training, add wandb
trainer.test(dataloaders=self.test_batch) # test, add wandb
trainer.fit(
self.module, self.train_batch, self.test_batch
) # training, add wandb
# trainer.test(dataloaders=self.test_batch) # test, add wandb

# return trainer

@torch.no_grad()
def get_latent_representation(
self,
adata: AnnData,
protein_correlations_obsm_key: str,
cell_morphology_obsm_key: str,
is_trained_model: Optional[bool] = True,
) -> np.ndarray:
"""
Gives the latent representation of each cell.
"""
if is_trained_model:
data_train, data_test = self.setup_anndata(
adata,
protein_correlations_obsm_key,
cell_morphology_obsm_key,
is_trained_model=is_trained_model,
)
train_mu_z = self.module.inference(data_train)
test_mu_z = self.module.inference(data_test)
return train_mu_z, test_mu_z
else:
raise Exception(
"No latent representation to produce! Model is not trained!"
)

# @setup_anndata_dsp.dedent
@staticmethod
def setup_anndata(
@@ -136,6 +170,7 @@ def setup_anndata(
train_prop: Optional[float] = 0.75,
apply_winsorize: Optional[bool] = True,
arctanh_corrs: Optional[bool] = False,
is_trained_model: Optional[bool] = False,
random_seed: Optional[int] = 1234,
copy: bool = False,
) -> Optional[AnnData]:
@@ -196,6 +231,9 @@ def setup_anndata(
data_test = ScModeDataloader(adata_test, data_train.scalers)

loader_train = DataLoader(data_train, batch_size=batch_size, shuffle=True)
loader_test = DataLoader(data_test, batch_size=batch_size, shuffle=True)
loader_test = DataLoader(data_test, batch_size=batch_size) # shuffle=True)

return loader_train, loader_test
if is_trained_model:
return data_train, data_test
else:
return loader_train, loader_test, len(samples_train)
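
The third value returned by setup_anndata, len(samples_train), is what the constructor stores as self.n_covariates, i.e. the width of the per-sample one-hot encoding. A rough sketch of that bookkeeping (the sample split itself is not shown in this diff, so the split logic below is assumed):

import numpy as np

rng = np.random.default_rng(1234)
all_samples = np.array(["s1", "s2", "s3", "s4"])               # hypothetical Sample_name values
n_train = int(0.75 * len(all_samples))                         # train_prop default above
samples_train = rng.choice(all_samples, n_train, replace=False)

n_covariates = len(samples_train)                              # third return value -> self.n_covariates
print(n_covariates)                                            # 3
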
241 changes: 175 additions & 66 deletions hmivae/_hmivae_module.py
@@ -1,11 +1,13 @@
from typing import Iterable, Optional # Dict, Tuple, Union
from typing import List, Optional, Sequence

import numpy as np
import pytorch_lightning as pl
import torch
import torch.nn.functional as F
from _hmivae_base_components import DecoderHMIVAE, EncoderHMIVAE

# from anndata import AnnData

torch.backends.cudnn.benchmark = True


@@ -25,10 +27,12 @@ def __init__(
E_mr: int = 32,
E_sc: int = 32,
latent_dim: int = 10,
n_covariates: int = 0,
n_hidden: int = 1,
):
super().__init__()
# hidden_dim = E_me + E_cr + E_mr + E_sc
self.n_covariates = n_covariates

self.encoder = EncoderHMIVAE(
input_exp_dim,
@@ -40,6 +44,7 @@ def __init__(
E_mr,
E_sc,
latent_dim,
n_covariates=n_covariates,
)

self.decoder = DecoderHMIVAE(
@@ -52,6 +57,7 @@ def __init__(
input_corr_dim,
input_morph_dim,
input_spcont_dim,
n_covariates=n_covariates,
)

def reparameterization(self, mu, log_std):
@@ -91,10 +97,13 @@ def em_recon_loss(
dec_x_logstd_morph,
dec_x_mu_spcont,
dec_x_logstd_spcont,
covariates_mu,
covariates_std,
y,
s,
m,
c,
cov_list,
# weights=None,
):
"""Takes in the parameters output from the decoder,
@@ -119,17 +128,20 @@ def em_recon_loss(
dec_x_std_corr = torch.exp(dec_x_logstd_corr)
dec_x_std_morph = torch.exp(dec_x_logstd_morph)
dec_x_std_spcont = torch.exp(dec_x_logstd_spcont)
cov_std = torch.exp(covariates_std)
p_rec_exp = torch.distributions.Normal(dec_x_mu_exp, dec_x_std_exp + 1e-6)
p_rec_corr = torch.distributions.Normal(dec_x_mu_corr, dec_x_std_corr + 1e-6)
p_rec_morph = torch.distributions.Normal(dec_x_mu_morph, dec_x_std_morph + 1e-6)
p_rec_spcont = torch.distributions.Normal(
dec_x_mu_spcont, dec_x_std_spcont + 1e-6
)
p_rec_cov = torch.distributions.Normal(covariates_mu, cov_std + 1e-6)

log_p_xz_exp = p_rec_exp.log_prob(y)
log_p_xz_corr = p_rec_corr.log_prob(s)
log_p_xz_morph = p_rec_morph.log_prob(m)
log_p_xz_spcont = p_rec_spcont.log_prob(c) # already dense matrix
log_p_cov = p_rec_cov.log_prob(cov_list)

# if weights is None:
# log_p_xz_corr = p_rec_corr.log_prob(s)
@@ -142,8 +154,9 @@ def em_recon_loss(
log_p_xz_corr = log_p_xz_corr.sum(-1)
log_p_xz_morph = log_p_xz_morph.sum(-1)
log_p_xz_spcont = log_p_xz_spcont.sum(-1)
log_p_cov = log_p_cov.sum(-1)

return log_p_xz_exp, log_p_xz_corr, log_p_xz_morph, log_p_xz_spcont
return log_p_xz_exp, log_p_xz_corr, log_p_xz_morph, log_p_xz_spcont, log_p_cov

def neg_ELBO(
self,
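
The new covariate term follows the same pattern as the other modalities: covariates_std is treated as a log-std (it is exponentiated above), and the Normal log-likelihood of cov_list is summed over covariate dimensions and added alongside the other reconstruction terms. A standalone sketch with hypothetical shapes:

import torch

covariates_mu = torch.zeros(8, 3)                     # decoder output
covariates_logstd = torch.zeros(8, 3)                 # decoder output, interpreted as log-std
cov_list = torch.eye(3)[torch.randint(0, 3, (8,))]    # one-hot covariates for 8 cells

cov_std = torch.exp(covariates_logstd)
p_rec_cov = torch.distributions.Normal(covariates_mu, cov_std + 1e-6)
log_p_cov = p_rec_cov.log_prob(cov_list).sum(-1)      # one value per cell
print(log_p_cov.shape)                                # torch.Size([8])
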
@@ -157,16 +170,25 @@ def neg_ELBO(
dec_x_logstd_morph,
dec_x_mu_spcont,
dec_x_logstd_spcont,
covariates_mu,
covariates_std,
z,
y,
s,
m,
c,
cov_list,
# weights=None,
):
kl_div = self.KL_div(enc_x_mu, enc_x_logstd, z)

recon_lik_me, recon_lik_corr, recon_lik_mor, recon_lik_sc = self.em_recon_loss(
(
recon_lik_me,
recon_lik_corr,
recon_lik_mor,
recon_lik_sc,
reconstructed_covs,
) = self.em_recon_loss(
dec_x_mu_exp,
dec_x_logstd_exp,
dec_x_mu_corr,
@@ -175,10 +197,13 @@ def neg_ELBO(
dec_x_logstd_morph,
dec_x_mu_spcont,
dec_x_logstd_spcont,
covariates_mu,
covariates_std,
y,
s,
m,
c,
cov_list,
# weights,
)
return (
@@ -187,6 +212,7 @@ def neg_ELBO(
recon_lik_corr,
recon_lik_mor,
recon_lik_sc,
reconstructed_covs,
)

def loss(self, kl_div, recon_loss, beta: float = 1.0):
@@ -196,11 +222,10 @@ def loss(self, kl_div, recon_loss, beta: float = 1.0):
def training_step(
self,
train_batch,
batch_idx,
categories: Optional[Iterable[int]] = None,
corr_weights=False,
recon_weights=np.array([1.0, 1.0, 1.0, 1.0]),
beta=1.0,
categories: Optional[List[float]] = None,
):
"""
Carries out the training step.
@@ -214,8 +239,19 @@ def training_step(
S = train_batch[1]
M = train_batch[2]
spatial_context = train_batch[3]

mu_z, log_std_z = self.encoder(Y, S, M, spatial_context)
one_hot = train_batch[4]
batch_idx = train_batch[-1]
if categories is not None:
if len(categories) > 0:
categories = torch.Tensor(categories)[batch_idx, :]
else:
categories = torch.Tensor(categories)
else:
categories = torch.Tensor([])

cov_list = torch.cat([one_hot, categories], 1).float()
# print('train',cov_list.size())
mu_z, log_std_z = self.encoder(Y, S, M, spatial_context, cov_list)

z_samples = self.reparameterization(mu_z, log_std_z)
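
During training, the covariate vector handed to the encoder and decoder is the per-cell one-hot sample encoding from the dataloader, optionally concatenated with extra categories selected via batch_idx. A sketch of that concatenation with hypothetical values (the exact layout of categories is assumed, since it is not shown here):

import torch

one_hot = torch.eye(3)[torch.tensor([0, 2, 1, 0])]    # train_batch[4]: 4 cells x 3 samples
categories = torch.Tensor([[1.0], [0.0], [0.0]])      # one extra category per sample (assumed layout)
batch_idx = torch.tensor([0, 2, 1, 0])                # train_batch[-1]

cov_list = torch.cat([one_hot, categories[batch_idx, :]], 1).float()
print(cov_list.shape)                                 # torch.Size([4, 4])
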

@@ -229,15 +265,18 @@ def training_step(
log_std_x_morph_hat,
mu_x_spcont_hat,
log_std_x_spcont_hat,
covariates_mu,
covariates_std,
# weights,
) = self.decoder(z_samples)
) = self.decoder(z_samples, cov_list)

(
kl_div,
recon_lik_me,
recon_lik_corr,
recon_lik_mor,
recon_lik_sc,
reconstructed_covs,
) = self.neg_ELBO(
mu_z,
log_std_z,
@@ -249,11 +288,14 @@ def training_step(
log_std_x_morph_hat,
mu_x_spcont_hat,
log_std_x_spcont_hat,
covariates_mu,
covariates_std,
z_samples,
Y,
S,
M,
spatial_context,
cov_list,
# weights,
)

@@ -262,6 +304,7 @@ def training_step(
+ recon_weights[1] * recon_lik_corr
+ recon_weights[2] * recon_lik_mor
+ recon_weights[3] * recon_lik_sc
+ reconstructed_covs
)

loss = self.loss(kl_div, recon_loss, beta=beta)
@@ -278,15 +321,17 @@ def training_step(
"recon_lik_sc": recon_lik_sc.mean().item(),
}

def test_step(
def validation_step(
self,
test_batch,
batch_idx,
n_other_cat: int = 0,
L_iter: int = 10,
corr_weights=False,
recon_weights=np.array([1.0, 1.0, 1.0, 1.0]),
beta=1.0,
categories: Optional[List[float]] = None,
):
"""
"""---> Add random one-hot encoding
Carries out the validation/test step.
test_batch: torch.Tensor. Validation/test data,
spatial_context: torch.Tensor. Matrix with old mu_z integrated neighbours information,
@@ -298,62 +343,99 @@ def test_step(
S = test_batch[1]
M = test_batch[2]
spatial_context = test_batch[3]

mu_z, log_std_z = self.encoder(Y, S, M, spatial_context)

z_samples = self.reparameterization(mu_z, log_std_z)

# decoding
(
mu_x_exp_hat,
log_std_x_exp_hat,
mu_x_corr_hat,
log_std_x_corr_hat,
mu_x_morph_hat,
log_std_x_morph_hat,
mu_x_spcont_hat,
log_std_x_spcont_hat,
# weights,
) = self.decoder(z_samples)

(
kl_div,
recon_lik_me,
recon_lik_corr,
recon_lik_mor,
recon_lik_sc,
) = self.neg_ELBO(
mu_z,
log_std_z,
mu_x_exp_hat,
log_std_x_exp_hat,
mu_x_corr_hat,
log_std_x_corr_hat,
mu_x_morph_hat,
log_std_x_morph_hat,
mu_x_spcont_hat,
log_std_x_spcont_hat,
z_samples,
Y,
S,
M,
spatial_context,
# weights,
)

recon_loss = (
recon_weights[0] * recon_lik_me
+ recon_weights[1] * recon_lik_corr
+ recon_weights[2] * recon_lik_mor
+ recon_weights[3] * recon_lik_sc
)

loss = self.loss(kl_div, recon_loss, beta=beta)

self.log("test_loss", loss, on_step=True, on_epoch=True, prog_bar=True)
batch_idx = test_batch[-1]
# print(batch_idx)
test_loss = []
n_classes = self.n_covariates # - n_other_cat
for i in range(L_iter):
# print(n_classes)
# print(len(batch_idx))
# print(np.eye(n_classes)[np.random.choice(n_classes, len(batch_idx))])
one_hot = self.random_one_hot(n_classes=n_classes, n_samples=len(batch_idx))
# print(one_hot.size())

if categories is not None:
if len(categories) > 0:
categories = torch.Tensor(categories)[batch_idx, :]
else:
categories = torch.Tensor(categories)
else:
categories = torch.Tensor([])

cov_list = torch.cat([one_hot, categories], 1).float()

mu_z, log_std_z = self.encoder(
Y, S, M, spatial_context, cov_list
) # valid step

z_samples = self.reparameterization(mu_z, log_std_z)

# decoding
(
mu_x_exp_hat,
log_std_x_exp_hat,
mu_x_corr_hat,
log_std_x_corr_hat,
mu_x_morph_hat,
log_std_x_morph_hat,
mu_x_spcont_hat,
log_std_x_spcont_hat,
covariates_mu,
covariates_std,
# weights,
) = self.decoder(z_samples, cov_list)

(
kl_div,
recon_lik_me,
recon_lik_corr,
recon_lik_mor,
recon_lik_sc,
reconstructed_covs,
) = self.neg_ELBO(
mu_z,
log_std_z,
mu_x_exp_hat,
log_std_x_exp_hat,
mu_x_corr_hat,
log_std_x_corr_hat,
mu_x_morph_hat,
log_std_x_morph_hat,
mu_x_spcont_hat,
log_std_x_spcont_hat,
covariates_mu,
covariates_std,
z_samples,
Y,
S,
M,
spatial_context,
cov_list,
# weights,
)

recon_loss = (
recon_weights[0] * recon_lik_me
+ recon_weights[1] * recon_lik_corr
+ recon_weights[2] * recon_lik_mor
+ recon_weights[3] * recon_lik_sc
+ reconstructed_covs
)

loss = self.loss(kl_div, recon_loss, beta=beta)

test_loss.append(loss)

self.log(
"test_loss",
sum(test_loss) / L_iter,
on_step=True,
on_epoch=True,
prog_bar=True,
) # log the average test loss over all the iterations

return {
"loss": loss,
"loss": sum(test_loss) / L_iter,
"kl_div": kl_div.mean().item(),
"recon_loss": recon_loss.mean().item(),
"recon_lik_me": recon_lik_me.mean().item(),
@@ -367,7 +449,8 @@ def configure_optimizers(self):
optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
return optimizer

def __get_input_embeddings__(
@torch.no_grad()
def get_input_embeddings(
self, x_mean, x_correlations, x_morphology, x_spatial_context
):
"""
@@ -386,3 +469,29 @@ def __get_input_embeddings__(
h_spatial_context2 = F.elu(self.spatial_context_hidden(h_spatial_context))

return h_mean2, h_correlations2, h_morphology2, h_spatial_context2

@torch.no_grad()
def inference(
self, data, indices: Optional[Sequence[int]] = None, give_mean: bool = True
) -> np.ndarray:
"""
Return the latent representation of each cell.
"""
if give_mean:
mu_z, _ = self.encoder(data.Y, data.S, data.M, data.C)

return mu_z.numpy()
else:
mu_z, log_std_z = self.encoder(data.Y, data.S, data.M, data.C)
z = self.reparameterization(mu_z, log_std_z)

return z.numpy()

@torch.no_grad()
def random_one_hot(self, n_classes: int, n_samples: int):
"""
Generates a random one hot encoded matrix.
From: https://stackoverflow.com/questions/45093615/random-one-hot-matrix-in-numpy
"""
# x = np.eye(n_classes)
return torch.Tensor(np.eye(n_classes)[np.random.choice(n_classes, n_samples)])