Skip to content

Commit

Permalink
Fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
ebezzi committed Jul 5, 2024
1 parent 6798b0c commit 88936dc
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 63 deletions.
112 changes: 51 additions & 61 deletions tools/models/metrics/run-scib.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
import datetime
import functools
import itertools
import pickle
import sys
import warnings
from typing import List

import cellxgene_census
import numpy as np
Expand All @@ -13,97 +13,90 @@
import scib_metrics
import tiledbsoma as soma
import yaml

import numpy as np
import scipy as sp
import cellxgene_census
import functools

from sklearn.preprocessing import LabelEncoder
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression
from sklearn import svm
from sklearn.ensemble import RandomForestClassifier

from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score, roc_auc_score
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder

warnings.filterwarnings("ignore")

class CensusClassifierMetrics:
    """Score an embedding by how well simple classifiers predict labels from it.

    The ``*_labels`` methods return the classifier score as-is. The ``*_batch``
    methods return ``1 - score``. Every method delegates to ``_base_accuracy``.
    """

    # Metric names accepted by ``_base_accuracy``.
    _SUPPORTED_METRICS = ("accuracy", "roc_auc")

    def __init__(self):
        # Metric used whenever a caller passes ``metric=None``.
        self._default_metric = "accuracy"

    def lr_labels(self, X, labels, metric=None):
        """Return the score of a ``LogisticRegression`` predicting ``labels`` from ``X``."""
        return self._base_accuracy(X, labels, LogisticRegression, metric=metric)

    def svm_svc_labels(self, X, labels, metric=None):
        """Return the score of an ``svm.SVC`` predicting ``labels`` from ``X``."""
        return self._base_accuracy(X, labels, svm.SVC, metric=metric)

    def random_forest_labels(self, X, labels, metric=None, n_jobs=8):
        """Return the score of a ``RandomForestClassifier`` predicting ``labels`` from ``X``."""
        return self._base_accuracy(X, labels, RandomForestClassifier, metric=metric, n_jobs=n_jobs)

    def lr_batch(self, X, batch, metric=None):
        """Return ``1 - score`` of a ``LogisticRegression`` predicting ``batch`` from ``X``."""
        return 1 - self._base_accuracy(X, batch, LogisticRegression, metric=metric)

    def svm_svc_batch(self, X, batch, metric=None):
        """Return ``1 - score`` of an ``svm.SVC`` predicting ``batch`` from ``X``."""
        return 1 - self._base_accuracy(X, batch, svm.SVC, metric=metric)

    def random_forest_batch(self, X, batch, metric=None, n_jobs=8):
        """Return ``1 - score`` of a ``RandomForestClassifier`` predicting ``batch`` from ``X``."""
        return 1 - self._base_accuracy(X, batch, RandomForestClassifier, metric=metric, n_jobs=n_jobs)

    def _base_accuracy(self, X, y, model, metric, test_size=0.4, **kwargs):
        """Train ``model`` on a split of ``X``/``y`` and score it on the held-out part.

        Args:
            X: Feature matrix handed to ``train_test_split``.
            y: Labels; encoded with ``LabelEncoder`` before splitting.
            model: Classifier class; instantiated as ``model(**kwargs)``.
            metric: ``"accuracy"``, ``"roc_auc"``, or ``None`` to use the default metric.
            test_size: Fraction passed to ``train_test_split`` (fixed ``random_state=42``).
            **kwargs: Forwarded to the ``model`` constructor.

        Returns:
            The accuracy or ROC AUC measured on the held-out split.

        Raises:
            ValueError: If ``metric`` is not one of the supported names.
        """
        if metric is None:
            metric = self._default_metric
        # Reject an unknown metric before spending time on training.
        if metric not in self._SUPPORTED_METRICS:
            raise ValueError("Only {'accuracy', 'roc_auc'} are supported as a metric")

        y_encoded = LabelEncoder().fit_transform(y)
        X_train, X_test, y_train, y_test = train_test_split(X, y_encoded, test_size=test_size, random_state=42)
        fitted = model(**kwargs).fit(X_train, y_train)

        if metric == "accuracy":
            return accuracy_score(y_test, fitted.predict(X_test))

        # NOTE(review): ``predict_proba`` is presumably unavailable on ``svm.SVC`` unless it is
        # built with ``probability=True`` - confirm before requesting "roc_auc" for the SVC methods.
        probabilities = fitted.predict_proba(X_test)
        if probabilities.shape[1] == 2:
            # Binary targets: roc_auc_score expects a 1-D score for the positive class,
            # not the full two-column probability matrix.
            return roc_auc_score(y_test, probabilities[:, 1])
        return roc_auc_score(y_test, probabilities, multi_class="ovo", average="macro")



def safelog(a):
    """Return the elementwise natural log of ``a``, leaving 0 wherever ``a`` is 0.

    Args:
        a: Array-like of numbers. Integer and boolean input is converted to
            ``float64``; floating and complex dtypes are kept as they are.

    Returns:
        A new array of the same shape as ``a`` holding ``log(a)`` at non-zero
        positions and 0 elsewhere.
    """
    values = np.asarray(a)
    # np.log cannot write its float result into an integer ``out`` buffer,
    # so promote anything that is not already floating/complex.
    if not np.issubdtype(values.dtype, np.inexact):
        values = values.astype(np.float64)
    return np.log(values, out=np.zeros_like(values), where=(values != 0))


def nearest_neighbors_hnsw(x, ef=200, M=48, n_neighbors=100):
    """Find the ``n_neighbors`` nearest rows of ``x`` for every row of ``x`` with hnswlib.

    Args:
        x: 2-D array; ``x.shape[0]`` items of dimension ``x.shape[1]``.
        ef: Used as ``ef_construction`` when building the index and as the
            query-time ``ef`` (raised to ``n_neighbors`` if smaller).
        M: Passed to ``init_index`` as the hnswlib ``M`` parameter.
        n_neighbors: Number of neighbors requested per row.

    Returns:
        ``(idx, dist)`` as returned by ``knn_query``: neighbor row indices and
        the matching distances in hnswlib's ``"l2"`` space.
    """
    import hnswlib

    n_items, dim = x.shape[0], x.shape[1]
    index = hnswlib.Index(space="l2", dim=dim)
    index.init_index(max_elements=n_items, ef_construction=ef, M=M)
    # Items are labelled by row position so returned labels index straight into ``x``.
    index.add_items(x, np.arange(n_items))
    # hnswlib requires the query-time ef to be at least k.
    index.set_ef(max(ef, n_neighbors))
    idx, dist = index.knn_query(x, k=n_neighbors)
    return idx, dist

def compute_entropy_per_cell(adata, obsm_key):
    """Return, per cell, the entropy of batch labels among its 200 nearest neighbors.

    A combined batch label is built by concatenating the string values of the
    ``dataset_id``, ``assay`` and ``suspension_type`` columns and is stored in
    ``adata.obs["batch"]`` (``adata`` is modified in place). Neighbors are found
    in ``adata.obsm[obsm_key]`` with ``nearest_neighbors_hnsw``.

    Args:
        adata: Object exposing ``obs`` (with the three columns above) and ``obsm``.
        obsm_key: Key of the embedding in ``adata.obsm``.

    Returns:
        1-D array with one entropy value (natural log) per cell.
    """
    batch_keys = ["dataset_id", "assay", "suspension_type"]
    key_columns = [adata.obs[key].astype(str) for key in batch_keys]
    adata.obs["batch"] = functools.reduce(lambda left, right: left + right, key_columns)

    neighbor_idx, _ = nearest_neighbors_hnsw(adata.obsm[obsm_key], n_neighbors=200)

    batch_of_cell = np.array(list(adata.obs["batch"]))
    # Batch label of every neighbor, one row per cell.
    neighbor_batches = batch_of_cell[neighbor_idx]

    # One column per distinct batch: how many of the cell's neighbors carry that batch.
    per_batch_counts = [(neighbor_batches == name).sum(1) for name in np.unique(batch_of_cell)]
    counts = np.vstack(per_batch_counts).T
    fractions = counts / counts.sum(1)[:, None]
    return (-fractions * safelog(fractions)).sum(1)


if __name__ == "__main__":
try:
Expand Down Expand Up @@ -162,16 +155,15 @@ def class_mapper():
subclass_dict = subclass_mapper()

def build_anndata_with_embeddings(
embedding_names: List[str],
embedding_names: list[str],
embeddings_raw: dict,
coords: List[int] = None,
coords: list[int] = None,
obs_value_filter: str = None,
column_names=dict,
census_version: str = None,
experiment_name: str = None,
):
"""
For a given set of Census cell coordinates (soma_joinids)
"""For a given set of Census cell coordinates (soma_joinids)
fetch embeddings with TileDBSoma and return the corresponding
AnnData with embeddings slotted in.
Expand All @@ -181,7 +173,6 @@ def build_anndata_with_embeddings(
Assume that all embeddings provided are coming from the same experiment.
"""

with cellxgene_census.open_soma(census_version=census_version) as census:
print("Getting anndata with Census embeddings: ", embedding_names)

Expand Down Expand Up @@ -253,7 +244,6 @@ def build_anndata_with_embeddings(
batch_metrics = metrics_config["batch"]

for tissue_node in tissues:

tissue = tissue_node["name"]
query = tissue_node.get("query") or f"tissue_general == '{tissue}' and is_primary_data == True"

Expand Down Expand Up @@ -363,13 +353,12 @@ def __init__(self, conn):
if "classifier" in bio_metrics:
metrics = CensusClassifierMetrics()

m1 = metrics.lr_labels(X=adata_metrics.obsm[emb], labels = adata_metrics.obs["cell_type"])
m2 = metrics.svm_svc_labels(X=adata_metrics.obsm[emb], labels = adata_metrics.obs["cell_type"])
m3 = metrics.random_forest_labels(X=adata_metrics.obsm[emb], labels = adata_metrics.obs["cell_type"])
m1 = metrics.lr_labels(X=adata_metrics.obsm[emb], labels=adata_metrics.obs[bio_label])
m2 = metrics.svm_svc_labels(X=adata_metrics.obsm[emb], labels=adata_metrics.obs[bio_label])
m3 = metrics.random_forest_labels(X=adata_metrics.obsm[emb], labels=adata_metrics.obs[bio_label])

metric_bio_results["classifier"].append({"lr": m1, "svm": m2, "random_forest": m3})


for batch_label, emb in itertools.product(batch_labels, embs):
print("\n\nSTART", batch_label, emb)

Expand Down Expand Up @@ -402,9 +391,9 @@ def __init__(self, conn):
if "classifier" in batch_metrics:
metrics = CensusClassifierMetrics()

m4 = metrics.lr_batch(X=adata_metrics.obsm[emb], batch = adata_metrics.obs[batch_label])
m5 = metrics.random_forest_batch(X=adata_metrics.obsm[emb], batch = adata_metrics.obs[batch_label])
m6 = metrics.svm_svc_batch(X=adata_metrics.obsm[emb], batch = adata_metrics.obs[batch_label])
m4 = metrics.lr_batch(X=adata_metrics.obsm[emb], batch=adata_metrics.obs[batch_label])
m5 = metrics.random_forest_batch(X=adata_metrics.obsm[emb], batch=adata_metrics.obs[batch_label])
m6 = metrics.svm_svc_batch(X=adata_metrics.obsm[emb], batch=adata_metrics.obs[batch_label])
metric_batch_results["classifier"].append({"lr": m4, "random_forest": m5, "svm": m6})

if "entropy" in batch_metrics:
Expand All @@ -414,8 +403,9 @@ def __init__(self, conn):
e_mean = entropy.mean()
metric_batch_results["entropy"].append(e_mean)

all_bio[tissue] = metric_bio_results
all_batch[tissue] = metric_batch_results
filename = f"metrics.{tissue}.pickle".replace(" ", "-").lower()

with open("metrics.pickle", "wb") as fp:
pickle.dump({"all_bio": all_bio, "all_batch": all_batch}, fp, protocol=pickle.HIGHEST_PROTOCOL)
with open(filename, "wb") as fp:
pickle.dump(
{"bio": metric_bio_results, "batch": metric_batch_results}, fp, protocol=pickle.HIGHEST_PROTOCOL
)
3 changes: 1 addition & 2 deletions tools/models/metrics/scib-metrics-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -15,5 +15,4 @@ metrics:
bio:
["leiden_nmi", "leiden_ari", "silhouette_label", "classifier"]
batch:
["silhouette_batch", "ilisi_knn_batch", "classifier"]

["silhouette_batch", "ilisi_knn_batch", "classifier", "entropy"]

0 comments on commit 88936dc

Please sign in to comment.