
Commit

Merge branch 'master' of github.com:linnarsson-lab/FISHscale
larsborm committed Apr 26, 2023
2 parents 440f234 + 7130826 commit caefc60
Showing 78 changed files with 450 additions and 20 deletions.
Empty file modified .gitignore
100755 → 100644
Empty file.
Empty file modified FISHscale/__init__.py
100755 → 100644
Empty file.
Empty file modified FISHscale/graphNN/__init__.py
100755 → 100644
Empty file.
58 changes: 46 additions & 12 deletions FISHscale/graphNN/cellularneighborhoods.py
@@ -10,7 +10,8 @@
import pytorch_lightning as pl
import pandas as pd
import dgl
from FISHscale.graphNN.models import SAGELightning
from FISHscale.graphNN.models_deepresidual import SAGELightning
#from FISHscale.graphNN.models import SAGELightning
from FISHscale.graphNN.graph_utils import GraphUtils, GraphPlotting
from FISHscale.graphNN.graph_decoder import GraphDecoder

@@ -110,10 +111,9 @@ def __init__(self,

self.unique_labels = np.unique(anndata.obs[self.label_name].values)
anndata = anndata[(anndata[:, self.genes].X.sum(axis=1) > 5), :]
anndata.raw = anndata
#anndata.raw = anndata
if normalize:
sc.pp.normalize_total(anndata, target_sum=1e4)
sc.pp.log1p(anndata)
self.anndata = anndata

### Model hyperparameters
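Illustrative sketch (not from the commit): the preprocessing this hunk leaves in place, written out as a standalone helper. The name preprocess_anndata is hypothetical, and a scanpy AnnData with raw counts in .X is assumed; note that sc.pp.log1p is no longer applied after this change.

import numpy as np
import scanpy as sc

def preprocess_anndata(anndata, genes, label_name, normalize=True):
    # record the label categories before filtering
    unique_labels = np.unique(anndata.obs[label_name].values)
    # keep only observations with more than 5 counts over the selected genes
    anndata = anndata[(anndata[:, genes].X.sum(axis=1) > 5), :]
    if normalize:
        # depth normalization only; log-transformation is no longer part of this step
        sc.pp.normalize_total(anndata, target_sum=1e4)
    return anndata, unique_labels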
@@ -398,8 +398,8 @@ def get_latents(self):
labelled (bool, optional): [description]. Defaults to True.
"""
self.model.eval()

self.latent_unlabelled, prediction_unlabelled = self.model.module.inference(
#self.g.to('cuda')
self.latent_unlabelled, _ = self.model.module.inference(
self.g,
self.model.device,
10*512,
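Illustrative usage sketch (not from the commit): inference() is still assumed to return a (latents, predictions) pair, but after this change only the latents are kept. The instance name cn is hypothetical.

cn.get_latents()                  # fills cn.latent_unlabelled via model.module.inference
latents = cn.latent_unlabelled    # per-node embeddings; the prediction output is discarded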
@@ -420,15 +420,16 @@ def get_attention(self):
labelled (bool, optional): [description]. Defaults to True.
"""
self.model.eval()
self.attention_ngh1, self.attention_ngh2 = self.model.module.inference_attention(
self.attention = self.model.module.inference_attention(
self.g,
self.model.device,
5*512,
0,
nodes=self.g.nodes(),
buffer_device=self.g.device)#.detach().numpy()
self.g.edata['attention1'] = self.attention_ngh1
self.g.edata['attention2'] = self.attention_ngh2
for e, a in enumerate(self.attention):
self.g.edata['attention{}'.format(e+1)] = a

self.save_graph()

def get_attention_nodes(self,nodes=None):
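Illustrative sketch (not from the commit): inference_attention now returns one tensor per layer rather than a fixed pair, and each is written to the graph as edge data keyed 'attention1', 'attention2', ... A minimal way to read them back, with a hypothetical instance name cn:

cn.get_attention()
attention_keys = sorted(k for k in cn.g.edata.keys() if k.startswith('attention'))
per_layer = [cn.g.edata[k] for k in attention_keys]  # one edge-attention tensor per GNN layer
print('layers with stored attention:', len(per_layer))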
@@ -443,13 +444,13 @@ def get_attention_nodes(self,nodes=None):
labelled (bool, optional): [description]. Defaults to True.
"""
self.model.eval()
att1,att2 = self.model.module.inference_attention(self.g,
att = self.model.module.inference_attention(self.g,
self.model.device,
5*512,
0,
nodes=nodes,
buffer_device=self.g.device)#.detach().numpy()
return att1, att2
return att

def compute_distance_th(self,coords):
"""
@@ -465,8 +466,8 @@

from scipy.spatial import cKDTree as KDTree
kdT = KDTree(coords)
d,i = kdT.query(coords,k=3)
d_th = np.percentile(d[:,-1],95)*self.distance_factor
d,i = kdT.query(coords,k=2)
d_th = np.percentile(d[:,-1],97)*self.distance_factor
logging.info('Chosen dist to connect molecules into a graph: {}'.format(d_th))
print('Chosen dist to connect molecules into a graph: {}'.format(d_th))
return d_th
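Illustrative sketch (not from the commit): with k=2 the second column of the KDTree query holds each point's nearest-neighbour distance (the first column is the point itself at distance 0), so the threshold becomes the 97th percentile of nearest-neighbour distances scaled by distance_factor. A standalone version of the same rule, with a hypothetical helper name distance_threshold:

import numpy as np
from scipy.spatial import cKDTree as KDTree

def distance_threshold(coords, distance_factor=1.0):
    # coords: (n_points, n_dims) array of molecule coordinates
    kdT = KDTree(coords)
    d, _ = kdT.query(coords, k=2)   # d[:, 0] is the self-distance (0), d[:, 1] the nearest neighbour
    return np.percentile(d[:, -1], 97) * distance_factor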
@@ -565,9 +566,42 @@ def cluster(self, n_clusters=10):
[type]: [description]
"""
from sklearn.cluster import MiniBatchKMeans
import scanpy as sc
from sklearn.linear_model import SGDClassifier
from sklearn.preprocessing import StandardScaler
from sklearn.pipeline import make_pipeline

clusters = MiniBatchKMeans(n_clusters=n_clusters).fit_predict(self.latent_unlabelled)
self.g.ndata['CellularNgh'] = th.tensor(clusters)

'''logging.info('Latent embeddings generated for {} molecules'.format(self.latent_unlabelled.shape[0]))
random_sample_train = np.random.choice(
len(self.latent_unlabelled ),
np.min([len(self.latent_unlabelled ),250000]),
replace=False)
training_latents = self.latent_unlabelled[random_sample_train,:]
adata = sc.AnnData(X=training_latents.detach().numpy())
logging.info('Building neighbor graph for clustering...')
sc.pp.neighbors(adata, n_neighbors=15)
logging.info('Running Leiden clustering...')
sc.tl.leiden(adata, random_state=42, resolution=1)
logging.info('Leiden clustering done.')
clusters= adata.obs['leiden'].values
logging.info('Total of {} found'.format(len(np.unique(clusters))))
clf = make_pipeline(StandardScaler(), SGDClassifier(loss='log_loss', max_iter=1000, tol=1e-3))
clf.fit(training_latents, clusters)
clusters = clf.predict(self.latent_unlabelled).astype('int8')
clf_total = make_pipeline(StandardScaler(), SGDClassifier(loss='log_loss', max_iter=1000, tol=1e-3))
clf_total.fit(self.latent_unlabelled.detach().numpy(), clusters)
clusters = clf.predict(self.latent_unlabelled.detach().numpy()).astype('int8')
self.g.ndata['CellularNgh'] = th.tensor(clusters)'''


self.save_graph()
return clusters
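Illustrative sketch (not from the commit): the active clustering path in this commit is plain MiniBatchKMeans on the latent embeddings, with the Leiden + SGDClassifier route kept only as commented-out code. A standalone version with a hypothetical helper name cluster_latents:

import torch as th
from sklearn.cluster import MiniBatchKMeans

def cluster_latents(latents, g, n_clusters=10):
    # latents: (n_nodes, latent_dim) embeddings produced by get_latents()
    clusters = MiniBatchKMeans(n_clusters=n_clusters).fit_predict(latents)
    g.ndata['CellularNgh'] = th.tensor(clusters)  # store labels as node data, as in cluster()
    return clusters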

Empty file modified FISHscale/graphNN/cluster_utils.py
100755 → 100644
Empty file.
Empty file modified FISHscale/graphNN/graph_decoder.py
100755 → 100644
Empty file.
Empty file modified FISHscale/graphNN/graph_pci.py
100755 → 100644
Empty file.
Empty file modified FISHscale/graphNN/graph_utils.py
100755 → 100644
Empty file.
Empty file modified FISHscale/graphNN/graphdata.py
100755 → 100644
Empty file.
Empty file modified FISHscale/graphNN/models.py
100755 → 100644
Empty file.
