Skip to content

Commit

Permalink
Merge pull request #63 from loucerac/develop
Browse files Browse the repository at this point in the history
Develop
  • Loading branch information
loucerac authored Aug 15, 2023
2 parents 778093d + cd11a6d commit 43d9e8d
Show file tree
Hide file tree
Showing 23 changed files with 4,520 additions and 96 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
[![DOI](https://zenodo.org/badge/362395439.svg)](https://zenodo.org/badge/latestdoi/362395439)
[![DOI](https://zenodo.org/badge/362395439.svg)](https://zenodo.org/badge/latestdoi/362395439) [![PyPI version](https://badge.fury.io/py/drexml.svg)](https://badge.fury.io/py/drexml) [![pdm-managed](https://img.shields.io/badge/pdm-managed-blueviolet)](https://pdm.fming.dev)

# Drug REpurposing using eXplainable Machine Learning and Mechanistic Models of signal transduction

Expand Down
46 changes: 42 additions & 4 deletions drexml/cli/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
)

from drexml.datasets import get_data
from drexml.plotting import plot_metrics
from drexml.plotting import RepurposingResult
from drexml.utils import (
check_gputree_availability,
get_number_cuda_devices,
Expand Down Expand Up @@ -314,13 +314,51 @@ def run(ctx, **kwargs):


# CLI sub-command `plot`: renders figures from previously saved result files.
# NOTE(review): this span is a rendered diff, so removed and added lines are
# interleaved and indentation is lost. The single `stab-path` argument, the
# `plot(ctx, stab_path)` signature and the `plot_metrics(stab_path)` call look
# like the pre-change version -- confirm against the repository before editing.
@main.command()
@click.argument("stab-path", type=click.Path(exists=True))
@click.argument("sel-path", type=click.Path(exists=True))
@click.argument("score-path", type=click.Path(exists=True))
@click.argument("stability-path", type=click.Path(exists=True))
@click.argument("output-folder", type=click.Path(exists=True))
@click.option(
"--gene",
type=str,
help="Gene (KDT) Symbol to plot its repurposing profile.",
)
@click.version_option(get_version())
@click.pass_context
def plot(ctx, stab_path):
def plot(ctx, sel_path, score_path, stability_path, output_folder, gene):
"""Plot the stability results"""

plot_metrics(stab_path)
# The three input paths are handed to RepurposingResult as the selection,
# score and stability matrices respectively.
results = RepurposingResult(
sel_mat=sel_path, score_mat=score_path, stab_mat=stability_path
)

# Tests already covered in plotting.
# When --gene is given, plot that gene's profile; failures are reported on
# stdout instead of being raised.
if gene: # pragma: no cover
try:
results.plot_gene_profile(gene=gene, output_folder=output_folder)
except KeyError as kerr:
# A KeyError is reported to the user as the gene being absent from the
# relevance matrix.
print(kerr)
click.echo(f"Gene {gene} not in relevance matrix.")
except Exception as e:
print(e)

else:
# Without --gene, plot the metrics; on failure the error is printed and the
# plot is skipped.
try:
results.plot_metrics(output_folder=output_folder)
except Exception as e: # pragma: no cover
print(e)
click.echo("skipping metrics plot.")

# The relevance heatmap is attempted once per `remove_unstable` value
# (True, then False); each failure is printed and that heatmap is skipped.
for use_filter in [True, False]:
try:
results.plot_relevance_heatmap(
remove_unstable=use_filter, output_folder=output_folder
)
except Exception as e: # pragma: no cover
print(e)
click.echo(
f"skipping relevance heatmap for filter set to: {use_filter}"
)


if __name__ == "__main__":
Expand Down
6 changes: 2 additions & 4 deletions drexml/cli/stab_explainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ def runner(model, bkg, new, check_add, use_gpu):
print(f"Shap summary results saved to: {shap_summary_fpath}")

shap_summary_renamed = convert_names(
shap_relevances.set_index(shap_relevances.columns[0]),
shap_relevances,
["circuits", "genes"],
axis=[0, 1],
)
Expand All @@ -183,9 +183,7 @@ def runner(model, bkg, new, check_add, use_gpu):
(filt_i * 1).to_csv(fs_fpath, sep="\t")
print(f"Shap selection results saved to: {fs_fpath}")

fs_renamed = convert_names(
filt_i.set_index(filt_i.columns[0]), ["circuits", "genes"], axis=[0, 1]
)
fs_renamed = convert_names(filt_i, ["circuits", "genes"], axis=[0, 1])
fs_renamed.to_csv(
fs_fpath.absolute().parent.joinpath(f"{fs_fpath.stem}_symbol.tsv"),
sep="\t",
Expand Down
1 change: 1 addition & 0 deletions drexml/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

DEFAULT_DICT = {
"seed_genes": None,
"disease_id": None,
"use_physio": "true",
"gene_exp": None,
"gene_exp_zenodo": False,
Expand Down
81 changes: 78 additions & 3 deletions drexml/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import pathlib

import pandas as pd
import pystow
from pandas.errors import ParserError
from requests.exceptions import ConnectTimeout
from zenodo_client import Zenodo
Expand All @@ -14,7 +15,77 @@
RECORD_ID = "6020480"


def load_disgenet():
    """Download if necessary and load the Disgenet curated gene-disease table.

    Returns
    -------
    pd.DataFrame
        Disgenet curated gene-disease associations, restricted to the columns
        ``disease_name``, ``disease_id``, ``entrez_id`` and ``dga_score``
        (in that order).
    """
    # The address is assembled from its segments: "https:/" followed by the
    # "/" separator yields the "https://" scheme prefix.
    url_parts = [
        "https:/",
        "www.disgenet.org",
        "static",
        "disgenet_ap1",
        "files",
        "downloads",
        "curated_gene_disease_associations.tsv.gz",
    ]
    url = "/".join(url_parts)

    raw_table = pystow.ensure_csv(
        "drexml", "datasets", url=url, read_csv_kwargs={"sep": "\t"}
    )

    # Map the upstream column names onto the names used in this package.
    column_names = {
        "geneId": "entrez_id",
        "diseaseId": "disease_id",
        "diseaseName": "disease_name",
        "score": "dga_score",
    }
    kept_columns = ["disease_name", "disease_id", "entrez_id", "dga_score"]

    renamed = raw_table.rename(columns=column_names)
    return renamed.loc[:, kept_columns]


def get_gda(disease_id, k_top=40):
    """Return the genes associated with a disease in the Disgenet curated list.

    Parameters
    ----------
    disease_id : str
        Disease ID.
    k_top : int
        Keep at most ``k_top`` associations, ranked by GDA score (highest
        first).

    Returns
    -------
    list
        Unique gene IDs, as strings.
    """
    associations = load_disgenet()

    # Restrict to the requested disease, then keep the best-scoring rows.
    is_requested = associations["disease_id"] == disease_id
    top_hits = associations.loc[is_requested].nlargest(k_top, "dga_score")

    gene_ids = top_hits["entrez_id"].astype(str)
    return gene_ids.unique().tolist()


def load_physiological_circuits():
"""Load the list of physiological circuits.
Returns
-------
list
List of physiological circuit IDs.
"""
fpath = get_resource_path("circuit_names.tsv.gz")
circuit_names = pd.read_csv(fpath, sep="\t").set_index("circuit_id")
circuit_names.index = circuit_names.index.str.replace("-", ".").str.replace(
Expand Down Expand Up @@ -174,9 +245,12 @@ def preprocess_frame(res, env, key):
elif key == "pathvals":
return preprocess_activities(res)
elif key == "circuits":
return preprocess_map(
res, env["seed_genes"], env["circuits_column"], env["use_physio"]
)
gene_list = []
if env["seed_genes"]:
gene_list += env["seed_genes"]
if env["disease_id"]:
gene_list += [str(gene) for gene in get_gda(env["disease_id"])]
return preprocess_map(res, gene_list, env["circuits_column"], env["use_physio"])
elif key == "genes":
return preprocess_genes(res, env["genes_column"])

Expand Down Expand Up @@ -279,6 +353,7 @@ def preprocess_map(frame, disease_seed_genes, circuits_column, use_physio):
"""
frame.index = frame.index.str.replace("-", ".").str.replace(" ", ".")
if disease_seed_genes:
print(disease_seed_genes)
disease_seed_genes = frame.columns.intersection(disease_seed_genes)
circuits = frame.index[frame[disease_seed_genes].any(axis=1)].tolist()
else:
Expand Down
Loading

0 comments on commit 43d9e8d

Please sign in to comment.