-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
0 parents
commit 3b94bbc
Showing
16 changed files
with
2,110 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
*.csv filter=lfs diff=lfs merge=lfs -text |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
data/ | ||
data_viz/ | ||
hammond_models/ | ||
saved_models/ | ||
archive/ | ||
results/ | ||
result_logs/ | ||
beeline_generated.zip | ||
__pycache__/ | ||
.ipynb_checkpoints | ||
data_viz_latest/ | ||
bash_scripts.ipynb | ||
hammond_export_net.ipynb | ||
final_report.ipynb | ||
netrexcf.ipynb | ||
gene_feature_exp.ipynb | ||
hammond_viz_old.ipynb | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,85 @@ | ||
# GRN-VAE | ||
|
||
This repository include code and documentation for GRN-VAE, a stablized SEM style Variational autoencoder for gene regulatory network inference. | ||
|
||
The pre-print of this paper could be found [here](https://bcb.cs.tufts.edu/GRN-VAE/GRNVAE_ISMB_submission.pdf) | ||
|
||
# Getting Started with GRN-VAE | ||
|
||
This document provides an end-to-end demonstration on how to infer GRN with our implementation of GRN-VAE. | ||
|
||
|
||
```python | ||
import numpy as np | ||
from data import load_beeline | ||
from logger import LightLogger | ||
from runner import runGRNVAE, runGRNVAE_ensemble, DEFAULT_GRNVAE_CONFIGS | ||
from runner import runDeepSEM, runDeepSEM_ensemble, DEFAULT_DEEPSEM_CONFIGS | ||
from evaluate import extract_edges, get_metrics | ||
import seaborn as sns | ||
import matplotlib.pyplot as plt | ||
``` | ||
|
||
## Model Configurations | ||
|
||
First you need to define some configs for running the model. Here we provide a default set of parameters in `runner` called `DEFAULT_GRNVAE_CONFIGS`. It comes in the forms of a standard python dictionary so it's very easy to modify as needed. | ||
|
||
The three key concepts proposed in the GRN-VAE paper are controlled by the following parameters. | ||
|
||
- `delayed_steps_on_sparse`: Number of delayed steps on introducing the sparse loss. | ||
- `dropout_augmentation`: The proportion of data that will be randomly masked as dropout in each traing step. | ||
- `train_on_non_zero`: Whether to train the model on non-zero expression data | ||
|
||
## Data loading | ||
[BEELINE benchmarks](https://github.com/Murali-group/Beeline) could be loaded by the `load_beeline` function, where you specify where to look for data and which benchmark to load. If it's the first time, this function will download the files automatically. | ||
|
||
The `data` object exported by `load_beeline` is an [annData](https://anndata.readthedocs.io/en/stable/generated/anndata.AnnData.html#anndata.AnnData) object read by [scanpy](https://scanpy.readthedocs.io/en/stable/). The `ground_truth` object includes ground truth edges based on the BEELINE benchmark but it's not required for network inference. | ||
|
||
When you use GRN-VAE on a real world data to discover noval regulatory relationship, here are a few tips on preparing your data: | ||
|
||
- You can read in data in any formats but make sure your data has genes in the column/var and cells in the rows/obs. Transpose your data if it's necessary. | ||
- Find out the most variable genes. Unlike many traditional algorithm, GRN-VAE has the capacity to run on large amount of data. Therefore you can set the number of variable genes very high. As described in the paper, we used 5,000 for our Hammond experiment. The only reason why we need this gene filter is to help converge the model. | ||
- Normalize your data. A simple log transformation is good enough. | ||
|
||
|
||
```python | ||
# Load data from a BEELINE benchmark | ||
data, ground_truth = load_beeline( | ||
data_dir='data', | ||
benchmark_data='hESC', | ||
benchmark_setting='500_STRING' | ||
) | ||
``` | ||
|
||
## Model Training | ||
|
||
Model training is simple with the `runGRNVAE` function. As said above, if ground truth is not available, just set `ground_truth` to be `None`. | ||
|
||
|
||
```python | ||
logger = LightLogger() | ||
# runGRNVAE initializes and trains a GRNVAE model with the configs specified. | ||
vae, adjs = runGRNVAE( | ||
data.X, DEFAULT_GRNVAE_CONFIGS, ground_truth=ground_truth, logger=logger) | ||
``` | ||
|
||
100%|██████████| 120/120 [00:33<00:00, 3.63it/s] | ||
|
||
|
||
The learned adjacency matrix could be obtained by the `get_adj()` method. For BEELINE benchmarks, you can get the performance metrics of this run using the `get_metrics` function. | ||
|
||
|
||
```python | ||
A = vae.get_adj() | ||
get_metrics(A, ground_truth) | ||
``` | ||
|
||
|
||
|
||
|
||
{'AUPR': 0.05958849485016752, | ||
'AUPRR': 2.4774368161948437, | ||
'EP': 504, | ||
'EPR': 4.922288423345506} | ||
|
||
We also provide our own implementation of [DeepSEM](https://www.nature.com/articles/s43588-021-00099-8). You can execute DeepSEM and the ensemble version of it using `runDeepSEM` and `runDeepSEM_ensemble`. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
#!/bin/bash | ||
#SBATCH --job-name=grnvae | ||
#SBATCH -p preempt | ||
#SBATCH -n 1 | ||
#SBATCH --gres=gpu:a100:1 | ||
#SBATCH --mem=6g | ||
#SBATCH --time=0-6:00:00 | ||
|
||
# >>> conda initialize >>> | ||
# !! Contents within this block are managed by 'conda init' !! | ||
__conda_setup="$('/cluster/tufts/slonimlab/hzhu07/miniconda3/bin/conda' 'shell.bash' 'hook' 2> /dev/null)" | ||
if [ $? -eq 0 ]; then | ||
eval "$__conda_setup" | ||
else | ||
if [ -f "/cluster/tufts/slonimlab/hzhu07/miniconda3/etc/profile.d/conda.sh" ]; then | ||
. "/cluster/tufts/slonimlab/hzhu07/miniconda3/etc/profile.d/conda.sh" | ||
else | ||
export PATH="/cluster/tufts/slonimlab/hzhu07/miniconda3/bin:$PATH" | ||
fi | ||
fi | ||
unset __conda_setup | ||
# <<< conda initialize <<< | ||
|
||
cd /cluster/tufts/slonimlab/hzhu07/grnvae | ||
conda activate grn | ||
python exp_beeline.py "$1" "$2" "$3" "$4" "$5" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
#!/bin/bash | ||
#SBATCH --job-name=grnvae | ||
#SBATCH -p preempt | ||
#SBATCH -n 2 | ||
#SBATCH --gres=gpu:a100:1 | ||
#SBATCH --mem=24g | ||
#SBATCH --time=0-12:00:00 | ||
|
||
# >>> conda initialize >>> | ||
# !! Contents within this block are managed by 'conda init' !! | ||
__conda_setup="$('/cluster/tufts/slonimlab/hzhu07/miniconda3/bin/conda' 'shell.bash' 'hook' 2> /dev/null)" | ||
if [ $? -eq 0 ]; then | ||
eval "$__conda_setup" | ||
else | ||
if [ -f "/cluster/tufts/slonimlab/hzhu07/miniconda3/etc/profile.d/conda.sh" ]; then | ||
. "/cluster/tufts/slonimlab/hzhu07/miniconda3/etc/profile.d/conda.sh" | ||
else | ||
export PATH="/cluster/tufts/slonimlab/hzhu07/miniconda3/bin:$PATH" | ||
fi | ||
fi | ||
unset __conda_setup | ||
# <<< conda initialize <<< | ||
|
||
cd /cluster/tufts/slonimlab/hzhu07/grnvae | ||
conda activate grn | ||
python exp_hammond.py "$1" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,103 @@ | ||
import numpy as np | ||
import scanpy as sc | ||
import pandas as pd | ||
import urllib | ||
from tqdm import tqdm | ||
import zipfile | ||
import os | ||
|
||
# Read ground truth | ||
def load_beeline_ground_truth(data_dir, gene_names): | ||
n_gene = len(gene_names) | ||
ground_truth = pd.read_csv(f'{data_dir}/label.csv') | ||
TF = set(ground_truth['Gene1']) | ||
All_gene = set(ground_truth['Gene1']) | set(ground_truth['Gene2']) | ||
|
||
evaluate_mask = np.zeros([n_gene, n_gene]) | ||
TF_mask = np.zeros([n_gene, n_gene]) | ||
for i, item in enumerate(gene_names): | ||
for j, item2 in enumerate(gene_names): | ||
if i == j: | ||
continue | ||
if item in TF and item2 in All_gene: | ||
evaluate_mask[i, j] = 1 | ||
if item in TF: | ||
TF_mask[i, j] = 1 | ||
|
||
truth_df = pd.DataFrame(np.zeros([n_gene, n_gene]), | ||
index=gene_names, columns=gene_names) | ||
for i in range(ground_truth.shape[0]): | ||
truth_df.loc[ground_truth.iloc[i, 0], ground_truth.iloc[i, 1]] = 1 | ||
A_truth = truth_df.values | ||
|
||
idx_source, idx_target = np.where(A_truth) | ||
truth_edges = set(zip(idx_source, idx_target)) | ||
|
||
eval_flat_mask = (evaluate_mask.flatten() != 0) | ||
y_true = A_truth.flatten()[eval_flat_mask] | ||
|
||
return eval_flat_mask, y_true, truth_edges | ||
|
||
def load_beeline(data_dir, benchmark_data='hESC', | ||
benchmark_setting='500_STRING'): | ||
''' Load BEELINE | ||
Load BEELINE data into memory (download if necessary). | ||
Parameters | ||
---------- | ||
data_dir: str | ||
Root folder where the BEELINE data is/will be located. | ||
benchmark_data: str | ||
Benchmark datasets. Choose among `hESC`, `hHep`, `mDC`, | ||
`mESC`, `mHSC`, `mHSC-GM`, and `mHSC-L`. | ||
benchmark_setting: str | ||
Benchmark settings. Choose among `500_STRING`, | ||
`1000_STRING`, `500_Non-ChIP`, `1000_Non-ChIP`, | ||
`500_ChIP-seq`, `1000_ChIP-seq`, `500_lofgof`, | ||
and `1000_lofgof`. If either of the `lofgof` settings | ||
is choosed, only `mESC` data is available. | ||
Returns | ||
------- | ||
tuple | ||
First element is a scanpy data with cells on rows and | ||
genes on columns. Second element is the corresponding | ||
BEELINE ground truth data | ||
''' | ||
if not os.path.exists(data_dir): | ||
os.mkdir(data_dir) | ||
if not os.path.exists(f'{data_dir}/BEELINE/'): | ||
download_beeline(data_dir) | ||
data_dir = f'{data_dir}/BEELINE/{benchmark_setting}_{benchmark_data}' | ||
data = sc.read(f'{data_dir}/data.csv') | ||
# We do need to transpose the data to have cells on rows and genes on columns | ||
data = data.transpose() | ||
ground_truth = load_beeline_ground_truth(data_dir, data.var_names) | ||
return data, ground_truth | ||
|
||
def download_beeline(save_dir, remove_zip=True): | ||
if not os.path.exists(save_dir): | ||
raise Exception("save_dir does not exist") | ||
zip_path = os.path.join(save_dir, 'BEELINE.zip') | ||
download_file('https://bcb.cs.tufts.edu/GRN-VAE/BEELINE.zip', | ||
zip_path) | ||
with zipfile.ZipFile(zip_path,"r") as zip_ref: | ||
for file in tqdm(desc='Extracting', iterable=zip_ref.namelist(), | ||
total=len(zip_ref.namelist())): | ||
zip_ref.extract(member=file, path=save_dir) | ||
if remove_zip: | ||
os.remove(zip_path) | ||
|
||
# Modified from https://gist.github.com/yanqd0/c13ed29e29432e3cf3e7c38467f42f51 | ||
def download_file(url, file_path, chunk_size=1024): | ||
req = urllib.request.urlopen(url) | ||
with open(file_path, 'wb') as f, tqdm( | ||
desc=f'Downloading {file_path}', total=req.length, unit='iB', | ||
unit_scale=True, unit_divisor=1024 | ||
) as bar: | ||
for _ in range(req.length // chunk_size + 1): | ||
chunk = req.read(chunk_size) | ||
if not chunk: break | ||
size = f.write(chunk) | ||
bar.update(size) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,132 @@ | ||
import numpy as np | ||
import pandas as pd | ||
from sklearn.metrics import average_precision_score | ||
|
||
# Modified from | ||
# https://github.com/HantaoShu/DeepSEM/blob/master/src/utils.py | ||
|
||
def get_metrics(A, ground_truth): | ||
''' Calculate Metrics including AUPR, AUPRR, EP, and EPR | ||
Calculate EPR given predicted adjacency matrix and BEELINE | ||
ground truth | ||
Parameters | ||
---------- | ||
A: numpy.array | ||
Predicted adjacency matrix. Expected size is |g| x |g|. | ||
ground_truth: tuple | ||
BEELINE ground truth object exported by | ||
data.load_beeline_ground_truth. The first element of this | ||
tuple is eval_flat_mask, the boolean mask on the flatten | ||
adjacency matrix to identify TFs and target genes. The | ||
second element is the lable values y_true after flatten. | ||
Returns | ||
------- | ||
tuple | ||
A tuple with AUPR, AUPR ratio, EP (in counts), and EPR | ||
''' | ||
eval_flat_mask, y_true, _ = ground_truth | ||
y_pred = np.abs(A.flatten()[eval_flat_mask]) | ||
|
||
AUPR = average_precision_score(y_true, y_pred) | ||
AUPRR = AUPR / np.mean(y_true) | ||
|
||
num_truth_edge = int(y_true.sum()) | ||
cutoff = np.partition(y_pred, -num_truth_edge)[-num_truth_edge] | ||
y_above_cutoff = y_pred > cutoff | ||
EP = int(np.sum(y_true[y_above_cutoff])) | ||
EPR = 1. * EP / ((num_truth_edge ** 2) / np.sum(eval_flat_mask)) | ||
|
||
return {'AUPR': AUPR, 'AUPRR': AUPRR, | ||
'EP': EP, 'EPR': EPR} | ||
|
||
# def top_k_filter(A, evaluate_mask, topk): | ||
# A= abs(A) | ||
# if evaluate_mask is None: | ||
# evaluate_mask = np.ones_like(A) - np.eye(len(A)) | ||
# A = A * evaluate_mask | ||
# A_val = list(np.sort(abs(A.reshape(-1, 1)), 0)[:, 0]) | ||
# A_val.reverse() | ||
# cutoff_all = A_val[topk] | ||
# A_above_cutoff = np.zeros_like(A) | ||
# A_above_cutoff[abs(A) > cutoff_all] = 1 | ||
# return A_above_cutoff | ||
|
||
# def get_epr(A, ground_truth): | ||
# ''' Calculate EPR | ||
|
||
# Calculate EPR given predicted adjacency matrix and BEELINE | ||
# ground truth | ||
|
||
# Parameters | ||
# ---------- | ||
# A: numpy.array | ||
# Predicted adjacency matrix. Expected size is |g| x |g|. | ||
# ground_truth: tuple | ||
# BEELINE ground truth object exported by | ||
# data.load_beeline_ground_truth. It's a tuple with the | ||
# first element being truth_edges and second element being | ||
# evaluate_mask. | ||
|
||
# Returns | ||
# ------- | ||
# tuple | ||
# A tuple with calculated EP (in counts) and EPR | ||
# ''' | ||
# eval_flat_mask, y_true, truth_edges, evaluate_mask = ground_truth | ||
# num_nodes = A.shape[0] | ||
# num_truth_edges = len(truth_edges) | ||
# A_above_cutoff = top_k_filter(A, evaluate_mask, num_truth_edges) | ||
# idx_source, idx_target = np.where(A_above_cutoff) | ||
# A_edges = set(zip(idx_source, idx_target)) | ||
# overlap_A = A_edges.intersection(truth_edges) | ||
# EP = len(overlap_A) | ||
# EPR = 1. * EP / ((num_truth_edges ** 2) / np.sum(evaluate_mask)) | ||
# return EP, EPR | ||
|
||
def extract_edges(A, gene_names=None, TFmask=None, threshold=0.0): | ||
'''Extract predicted edges | ||
Extract edges from the predicted adjacency matrix | ||
Parameters | ||
---------- | ||
A: numpy.array | ||
Predicted adjacency matrix. Expected size is |g| x |g|. | ||
gene_names: None, list or numpy.array | ||
(Optional) List of Gene Names. Usually accessible in the var_names | ||
field of scanpy data. | ||
TFmask: numpy.array | ||
A masking matrix indicating the position of TFs. Expected | ||
size is |g| x |g|. | ||
Returns | ||
------- | ||
pandas.DataFrame | ||
A DataFrame including all the predicted links with predicted | ||
link strength. | ||
''' | ||
num_nodes = A.shape[0] | ||
mat_indicator_all = np.zeros([num_nodes, num_nodes]) | ||
if TFmask is not None: | ||
A_masked = A * TFmask | ||
else: | ||
A_masked = A | ||
mat_indicator_all[abs(A_masked) > threshold] = 1 | ||
idx_source, idx_target = np.where(mat_indicator_all) | ||
if gene_names is None: | ||
source_lbl = idx_source | ||
target_lbl = idx_target | ||
else: | ||
source_lbl = gene_names[idx_source] | ||
target_lbl = gene_names[idx_target] | ||
edges_df = pd.DataFrame( | ||
{'Source': source_lbl, 'Target': target_lbl, | ||
'EdgeWeight': (A[idx_source, idx_target]), | ||
'AbsEdgeWeight': (np.abs(A[idx_source, idx_target])) | ||
}) | ||
edges_df = edges_df.sort_values('AbsEdgeWeight', ascending=False) | ||
|
||
return edges_df.reset_index(drop=True) |
Oops, something went wrong.