diff --git a/tools/models/metrics/run-scib.py b/tools/models/metrics/run-scib.py index 5d35a2f70..5627599d5 100644 --- a/tools/models/metrics/run-scib.py +++ b/tools/models/metrics/run-scib.py @@ -1,9 +1,9 @@ import datetime +import functools import itertools import pickle import sys import warnings -from typing import List import cellxgene_census import numpy as np @@ -13,88 +13,80 @@ import scib_metrics import tiledbsoma as soma import yaml - -import numpy as np -import scipy as sp -import cellxgene_census -import functools - -from sklearn.preprocessing import LabelEncoder -from sklearn.model_selection import train_test_split -from sklearn.linear_model import LogisticRegression from sklearn import svm from sklearn.ensemble import RandomForestClassifier - +from sklearn.linear_model import LogisticRegression from sklearn.metrics import accuracy_score, roc_auc_score +from sklearn.model_selection import train_test_split +from sklearn.preprocessing import LabelEncoder warnings.filterwarnings("ignore") -class CensusClassifierMetrics: +class CensusClassifierMetrics: def __init__(self): self._default_metric = "accuracy" - def lr_labels(self, X, labels, metric = None): + def lr_labels(self, X, labels, metric=None): return self._base_accuracy(X, labels, LogisticRegression, metric=metric) - def svm_svc_labels(self, X, labels, metric = None): + def svm_svc_labels(self, X, labels, metric=None): return self._base_accuracy(X, labels, svm.SVC, metric=metric) - def random_forest_labels(self, X, labels, metric = None, n_jobs=8): + def random_forest_labels(self, X, labels, metric=None, n_jobs=8): return self._base_accuracy(X, labels, RandomForestClassifier, metric=metric, n_jobs=n_jobs) - def lr_batch(self, X, batch, metric = None): - return 1-self._base_accuracy(X, batch, LogisticRegression, metric=metric) + def lr_batch(self, X, batch, metric=None): + return 1 - self._base_accuracy(X, batch, LogisticRegression, metric=metric) - def svm_svc_batch(self, X, batch, metric = None): - return 1-self._base_accuracy(X, batch, svm.SVC, metric=metric) + def svm_svc_batch(self, X, batch, metric=None): + return 1 - self._base_accuracy(X, batch, svm.SVC, metric=metric) - def random_forest_batch(self, X, batch, metric = None, n_jobs=8): - return 1-self._base_accuracy(X, batch, RandomForestClassifier, metric=metric, n_jobs=n_jobs) + def random_forest_batch(self, X, batch, metric=None, n_jobs=8): + return 1 - self._base_accuracy(X, batch, RandomForestClassifier, metric=metric, n_jobs=n_jobs) def _base_accuracy(self, X, y, model, metric, test_size=0.4, **kwargs): - """ - Train LogisticRegression on X with labels y and return classifier accuracy score - """ + """Train LogisticRegression on X with labels y and return classifier accuracy score""" y_encoded = LabelEncoder().fit_transform(y) - X_train, X_test, y_train, y_test = train_test_split( - X, y_encoded, test_size=test_size, random_state=42 - ) + X_train, X_test, y_train, y_test = train_test_split(X, y_encoded, test_size=test_size, random_state=42) model = model(**kwargs).fit(X_train, y_train) if metric == None: metric = self._default_metric - - if metric == "roc_auc": - #return y_test - #return model.predict_proba(X_test) + + if metric == "roc_auc": + # return y_test + # return model.predict_proba(X_test) return roc_auc_score(y_test, model.predict_proba(X_test), multi_class="ovo", average="macro") elif metric == "accuracy": return accuracy_score(y_test, model.predict(X_test)) else: raise ValueError("Only {'accuracy', 'roc_auc'} are supported as a metric") - + + def safelog(a): - return np.log(a, out=np.zeros_like(a), where=(a!=0)) + return np.log(a, out=np.zeros_like(a), where=(a != 0)) + -def nearest_neighbors_hnsw(x, ef=200, M=48, n_neighbors = 100): +def nearest_neighbors_hnsw(x, ef=200, M=48, n_neighbors=100): import hnswlib + labels = np.arange(x.shape[0]) - p = hnswlib.Index(space = 'l2', dim = x.shape[1]) - p.init_index(max_elements = x.shape[0], ef_construction = ef, M = M) + p = hnswlib.Index(space="l2", dim=x.shape[1]) + p.init_index(max_elements=x.shape[0], ef_construction=ef, M=M) p.add_items(x, labels) p.set_ef(ef) - idx, dist = p.knn_query(x, k = n_neighbors) - return idx,dist + idx, dist = p.knn_query(x, k=n_neighbors) + return idx, dist -def compute_entropy_per_cell(adata, obsm_key): +def compute_entropy_per_cell(adata, obsm_key): batch_keys = ["dataset_id", "assay", "suspension_type"] - adata.obs["batch"] = functools.reduce(lambda a, b: a+b, [adata.obs[c].astype(str) for c in batch_keys]) + adata.obs["batch"] = functools.reduce(lambda a, b: a + b, [adata.obs[c].astype(str) for c in batch_keys]) - indices, dist = nearest_neighbors_hnsw(adata.obsm[obsm_key], n_neighbors = 200) + indices, dist = nearest_neighbors_hnsw(adata.obsm[obsm_key], n_neighbors=200) - BATCH_KEY = 'batch' + BATCH_KEY = "batch" batch_labels = np.array(list(adata.obs[BATCH_KEY])) unique_batch_labels = np.unique(batch_labels) @@ -102,8 +94,9 @@ def compute_entropy_per_cell(adata, obsm_key): indices_batch = batch_labels[indices] label_counts_per_cell = np.vstack([(indices_batch == label).sum(1) for label in unique_batch_labels]).T - label_counts_per_cell_normed = label_counts_per_cell / label_counts_per_cell.sum(1)[:,None] - return (-label_counts_per_cell_normed*safelog(label_counts_per_cell_normed)).sum(1) + label_counts_per_cell_normed = label_counts_per_cell / label_counts_per_cell.sum(1)[:, None] + return (-label_counts_per_cell_normed * safelog(label_counts_per_cell_normed)).sum(1) + if __name__ == "__main__": try: @@ -162,16 +155,15 @@ def class_mapper(): subclass_dict = subclass_mapper() def build_anndata_with_embeddings( - embedding_names: List[str], + embedding_names: list[str], embeddings_raw: dict, - coords: List[int] = None, + coords: list[int] = None, obs_value_filter: str = None, column_names=dict, census_version: str = None, experiment_name: str = None, ): - """ - For a given set of Census cell coordinates (soma_joinids) + """For a given set of Census cell coordinates (soma_joinids) fetch embeddings with TileDBSoma and return the corresponding AnnData with embeddings slotted in. @@ -181,7 +173,6 @@ def build_anndata_with_embeddings( Assume that all embeddings provided are coming from the same experiment. """ - with cellxgene_census.open_soma(census_version=census_version) as census: print("Getting anndata with Census embeddings: ", embedding_names) @@ -253,7 +244,6 @@ def build_anndata_with_embeddings( batch_metrics = metrics_config["batch"] for tissue_node in tissues: - tissue = tissue_node["name"] query = tissue_node.get("query") or f"tissue_general == '{tissue}' and is_primary_data == True" @@ -363,13 +353,12 @@ def __init__(self, conn): if "classifier" in bio_metrics: metrics = CensusClassifierMetrics() - m1 = metrics.lr_labels(X=adata_metrics.obsm[emb], labels = adata_metrics.obs["cell_type"]) - m2 = metrics.svm_svc_labels(X=adata_metrics.obsm[emb], labels = adata_metrics.obs["cell_type"]) - m3 = metrics.random_forest_labels(X=adata_metrics.obsm[emb], labels = adata_metrics.obs["cell_type"]) + m1 = metrics.lr_labels(X=adata_metrics.obsm[emb], labels=adata_metrics.obs[bio_label]) + m2 = metrics.svm_svc_labels(X=adata_metrics.obsm[emb], labels=adata_metrics.obs[bio_label]) + m3 = metrics.random_forest_labels(X=adata_metrics.obsm[emb], labels=adata_metrics.obs[bio_label]) metric_bio_results["classifier"].append({"lr": m1, "svm": m2, "random_forest": m3}) - for batch_label, emb in itertools.product(batch_labels, embs): print("\n\nSTART", batch_label, emb) @@ -402,9 +391,9 @@ def __init__(self, conn): if "classifier" in batch_metrics: metrics = CensusClassifierMetrics() - m4 = metrics.lr_batch(X=adata_metrics.obsm[emb], batch = adata_metrics.obs[batch_label]) - m5 = metrics.random_forest_batch(X=adata_metrics.obsm[emb], batch = adata_metrics.obs[batch_label]) - m6 = metrics.svm_svc_batch(X=adata_metrics.obsm[emb], batch = adata_metrics.obs[batch_label]) + m4 = metrics.lr_batch(X=adata_metrics.obsm[emb], batch=adata_metrics.obs[batch_label]) + m5 = metrics.random_forest_batch(X=adata_metrics.obsm[emb], batch=adata_metrics.obs[batch_label]) + m6 = metrics.svm_svc_batch(X=adata_metrics.obsm[emb], batch=adata_metrics.obs[batch_label]) metric_batch_results["classifier"].append({"lr": m4, "random_forest": m5, "svm": m6}) if "entropy" in batch_metrics: @@ -414,8 +403,9 @@ def __init__(self, conn): e_mean = entropy.mean() metric_batch_results["entropy"].append(e_mean) - all_bio[tissue] = metric_bio_results - all_batch[tissue] = metric_batch_results + filename = f"metrics.{tissue}.pickle".replace(" ", "-").lower() - with open("metrics.pickle", "wb") as fp: - pickle.dump({"all_bio": all_bio, "all_batch": all_batch}, fp, protocol=pickle.HIGHEST_PROTOCOL) \ No newline at end of file + with open(filename, "wb") as fp: + pickle.dump( + {"bio": metric_bio_results, "batch": metric_batch_results}, fp, protocol=pickle.HIGHEST_PROTOCOL + ) diff --git a/tools/models/metrics/scib-metrics-config.yaml b/tools/models/metrics/scib-metrics-config.yaml index 94eef9085..2ba7d30eb 100644 --- a/tools/models/metrics/scib-metrics-config.yaml +++ b/tools/models/metrics/scib-metrics-config.yaml @@ -15,5 +15,4 @@ metrics: bio: ["leiden_nmi", "leiden_ari", "silhouette_label", "classifier"] batch: - ["silhouette_batch", "ilisi_knn_batch", "classifier"] - + ["silhouette_batch", "ilisi_knn_batch", "classifier", "entropy"] \ No newline at end of file