# imc_exp_setup.py
from scanpy import read_h5ad
import pandas as pd
import numpy as np
from sklearn import metrics
from sklearn.cluster import KMeans
import torch
import random
# Marker pairs expected to be positively co-expressed, used as the
# co-expression objectives.
positive_coexpression_pairs = [
    ['ECadherin', 'panCytokeratin'],
    ['CD45', 'CD20'],
    ['CD45', 'CD68'],
    ['CD45', 'CD3'],
    ['Vimentin', 'Fibronectin'],
    ['CD19', 'CD20'],
    ['panCytokeratin', 'CK5']
]
# This function loads the Jackson IMC (basel_zuri) dataset, keeps the Basel
# cohort, subsamples 5000 cells and attaches the PhenoGraph labels as ground
# truth.
def load_data():
    # Load data and keep only the Basel cores
    adata = read_h5ad("basel_zuri.h5ad")
    adata = adata[adata.obs.index.str.contains("Basel"), :]
    # Subsample 5000 cells with a fixed seed for reproducibility
    seed = 10
    np.random.seed(seed)
    torch.manual_seed(seed)
    random.seed(a=seed)
    to_sample = np.random.choice(adata.shape[0], size=5000, replace=False)
    adata = adata[to_sample, :]
    print(f"max in subsampled adata: {adata.X.max()}")
    print(f"shape of adata: {adata.shape}")
    # Load the ground-truth PhenoGraph clusters and align them to the cells
    clusters = pd.read_csv("PG_final_k20.csv").set_index("id")
    gt_clusters = list(clusters.PhenoGraphBasel[adata.obs_names])
    adata.obs['target'] = gt_clusters
    # Drop non-protein/background channels
    adata = adata[:, ~adata.var_names.str.contains("Rutheni|80ArArArAr80Di|Hg202|I127|In115|208PbPb208Di|Xe126|Xe131|Pb20|Xe13")]
    # Reference expression: raw counts, arcsinh-transformed without a cofactor
    Y = adata.X.copy()
    Y = np.arcsinh(Y)
    return adata, Y
def split_data(adata, Y):
    # Split subsampled data to train and test
    train = np.random.choice(adata.shape[0], size=int(adata.shape[0] * 0.7), replace=False)
    adata_train = adata[train, :]
    Y_train = Y[train, :]
    test = set(range(adata.shape[0])).difference(set(train))
    adata_test = adata[list(test), :]
    Y_test = Y[list(test), :]
    return adata_train, Y_train, adata_test, Y_test
# This function computes the correlation between the mean expressions of two
# proteins across clusters.
def get_coexpression(cluster_means, p1, p2, var_names):
    df = pd.DataFrame(cluster_means, columns=var_names)
    return df.corr()[p1][p2]
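# Illustrative (hypothetical) example of what get_coexpression returns: the
# Pearson correlation of two marker columns across the cluster means, e.g.
#   toy_means = np.array([[1.0, 0.9], [2.0, 2.1], [3.0, 2.9]])
#   get_coexpression(toy_means, 'CD45', 'CD3', ['CD45', 'CD3'])  # ~0.99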
# This function arcsinh-normalises the data with the given cofactor and
# clusters it with k-means.
def norm_adata(adata, cofactor=5., n_clusters=10):
    # Normalise data with the cofactor
    adata.X = np.arcsinh(adata.X / cofactor)
    km = KMeans(n_clusters, random_state=int(np.random.choice(1000)))
    # Cluster the normalised data
    km.fit(adata.X)
    adata.obs['leiden'] = km.labels_
    return adata
# This function computes the clustering resulting from each given cofactor and
# the corresponding objective values (the supervised objectives are computed
# w.r.t. the true labels).
def true_f(cofactors, adata, Y):
    results = {}
    from collections import defaultdict
    results_list = defaultdict(list)
    ari = []
    nmi = []
    for cofactor in cofactors:
        results[cofactor] = {}
        adata_norm = norm_adata(adata.copy(), cofactor)
        unique_clusters = np.unique(adata_norm.obs.leiden)
        # Mean reference expression (Y) of each cluster
        cluster_means = np.concatenate([Y[adata_norm.obs.leiden == cl, :].mean(0).reshape(1, -1) for cl in unique_clusters], axis=0)
        # Unsupervised objectives: co-expression of the positive marker pairs
        for pair in positive_coexpression_pairs:
            pair_str = pair[0] + "_" + pair[1] + "_+"
            results[cofactor][pair_str] = get_coexpression(cluster_means, pair[0], pair[1], adata_norm.var_names).astype(np.float32)
        # Supervised objectives w.r.t. the true labels
        ari.append(metrics.adjusted_rand_score(adata_norm.obs.target.astype(str), adata_norm.obs.leiden))
        nmi.append(metrics.normalized_mutual_info_score(adata_norm.obs.target.astype(str), adata_norm.obs.leiden))
        for k, v in results[cofactor].items():
            results_list[k].append(v)
    # Stack the per-cofactor objectives into (n_cofactors, 1) tensors
    pos_pairs = []
    for pair in positive_coexpression_pairs:
        pair_str = pair[0] + "_" + pair[1] + "_+"
        obj = np.array(results_list[pair_str])
        pos_pairs.append(torch.reshape(torch.tensor([obj]), (obj.shape[0], 1)))
    ari = np.array(ari)
    nmi = np.array(nmi)
    ari = torch.reshape(torch.tensor([ari]), (ari.shape[0], 1))
    nmi = torch.reshape(torch.tensor([nmi]), (nmi.shape[0], 1))
    y = torch.cat(pos_pairs, axis=1)
    return y, ari, nmi
# This function returns the objective labels, one per positive co-expression pair.
def get_labels():
    labels = []
    for pair in positive_coexpression_pairs:
        labels.append(pair[0] + "_" + pair[1] + "_+")
    return labels
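# Minimal usage sketch (not part of the original script). It assumes the data
# files referenced above are present; the cofactor grid and the printed output
# below are illustrative assumptions only.
if __name__ == "__main__":
    adata, Y = load_data()
    adata_train, Y_train, adata_test, Y_test = split_data(adata, Y)
    # Evaluate an assumed grid of candidate cofactors on the training split
    cofactors = [1., 5., 10., 50., 100.]
    y, ari, nmi = true_f(cofactors, adata_train, Y_train)
    labels = get_labels()
    # y: (len(cofactors), len(labels)); ari and nmi: (len(cofactors), 1)
    for i, cofactor in enumerate(cofactors):
        print(f"cofactor={cofactor}: ARI={ari[i].item():.3f}, NMI={nmi[i].item():.3f}")
        for j, label in enumerate(labels):
            print(f"  {label}: {y[i, j].item():.3f}")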