From 5e47b8bdb0e752341d549861f62db70c3d1730eb Mon Sep 17 00:00:00 2001 From: Carlos Loucera Date: Thu, 4 May 2023 23:09:05 +0200 Subject: [PATCH 1/4] Add cli option to enable check_add for shap. --- drexml/cli/cli.py | 13 +++++++++++++ drexml/cli/stab_explainer.py | 4 ++-- drexml/cli/stab_scorer.py | 2 +- drexml/cli/stab_trainer.py | 2 +- drexml/utils.py | 4 ++-- 5 files changed, 19 insertions(+), 6 deletions(-) diff --git a/drexml/cli/cli.py b/drexml/cli/cli.py index cb9649af..bedfc631 100644 --- a/drexml/cli/cli.py +++ b/drexml/cli/cli.py @@ -87,6 +87,16 @@ ] +_check_add_option = [ + click.option( + "--add/--no-add", + is_flag=True, + default=True, + help="Check the additivity when computing the SHAP values.", + ) +] + + def copy_files(ctx, fnames): """Copy files from tmp to ml folder.""" for fname in fnames: @@ -155,6 +165,7 @@ def build_cmd(ctx): str(int(ctx["n_gpus"])), str(ctx["n_cpus"]), str(int(ctx["debug"])), + str(int(ctx["add"])), ctx["mode"], ] @@ -230,6 +241,7 @@ def stability(**kwargs): @main.command() @add_options(_debug_option) +@add_options(_check_add_option) @add_options(_n_iters_option) @add_options(_n_gpus_option) @add_options(_n_cpus_option) @@ -247,6 +259,7 @@ def explain(**kwargs): @main.command() @add_options(_debug_option) +@add_options(_check_add_option) @add_options(_n_iters_option) @add_options(_n_gpus_option) @add_options(_n_cpus_option) diff --git a/drexml/cli/stab_explainer.py b/drexml/cli/stab_explainer.py index ef1ec9e3..f5b77f2d 100644 --- a/drexml/cli/stab_explainer.py +++ b/drexml/cli/stab_explainer.py @@ -19,7 +19,7 @@ if __name__ == "__main__": import sys - data_folder, n_iters, n_gpus, n_cpus, n_splits, debug = parse_stab(sys.argv) + data_folder, n_iters, n_gpus, n_cpus, n_splits, debug, add = parse_stab(sys.argv) this_seed = 82 queue = multiprocessing.Queue() @@ -112,7 +112,7 @@ def runner(model, bkg, new, check_add, use_gpu): model=this_model, bkg=features_bkg, new=gb, - check_add=True, + check_add=add, use_gpu=gpu, ) for _, gb in features_val.groupby( diff --git a/drexml/cli/stab_scorer.py b/drexml/cli/stab_scorer.py index 8e48139e..8a6119f4 100644 --- a/drexml/cli/stab_scorer.py +++ b/drexml/cli/stab_scorer.py @@ -20,7 +20,7 @@ # client = Client('127.0.0.1:8786') # pylint: disable=unbalanced-tuple-unpacking - data_folder, n_iters, n_gpus, n_cpus, n_splits, debug = parse_stab(sys.argv) + data_folder, n_iters, n_gpus, n_cpus, n_splits, debug, add = parse_stab(sys.argv) model, stab_cv, features, targets = get_stab( data_folder, n_splits, n_cpus, debug, n_iters ) diff --git a/drexml/cli/stab_trainer.py b/drexml/cli/stab_trainer.py index 74cfbbc2..ac6c1b5e 100644 --- a/drexml/cli/stab_trainer.py +++ b/drexml/cli/stab_trainer.py @@ -14,7 +14,7 @@ # client = Client('127.0.0.1:8786') # pylint: disable=unbalanced-tuple-unpacking - data_path, n_iters, n_gpus, n_cpus, n_splits, debug = parse_stab(sys.argv) + data_path, n_iters, n_gpus, n_cpus, n_splits, debug, add = parse_stab(sys.argv) model, stab_cv, X, Y = get_stab(data_path, n_splits, n_cpus, debug, n_iters) for i, split in enumerate(stab_cv): diff --git a/drexml/utils.py b/drexml/utils.py index 09674e8d..7ef96b4c 100644 --- a/drexml/utils.py +++ b/drexml/utils.py @@ -67,7 +67,7 @@ def parse_stab(argv): bool Debug flag. """ - _, data_folder, n_iters, n_gpus, n_cpus, debug, mode = argv + _, data_folder, n_iters, n_gpus, n_cpus, debug, add, mode = argv n_iters = int(n_iters) data_folder = Path(data_folder) n_gpus = int(n_gpus) @@ -79,7 +79,7 @@ def parse_stab(argv): else: n_splits = 3 if debug else 100 - return data_folder, n_iters, n_gpus, n_cpus, n_splits, debug + return data_folder, n_iters, n_gpus, n_cpus, n_splits, debug, add def get_stab(data_folder, n_splits, n_cpus, debug, n_iters): From 43e57ac1b6ff27be0e9ef934d0e9e3aedf92dcea Mon Sep 17 00:00:00 2001 From: Carlos Loucera Date: Fri, 5 May 2023 06:45:41 +0200 Subject: [PATCH 2/4] Fix bug when dealing with integers as bools. --- drexml/datasets.py | 1 + 1 file changed, 1 insertion(+) diff --git a/drexml/datasets.py b/drexml/datasets.py index 36de4735..23340ee7 100644 --- a/drexml/datasets.py +++ b/drexml/datasets.py @@ -148,6 +148,7 @@ def get_disease_data(disease, debug): pathvals.columns = pathvals.columns.str.replace("-", ".").str.replace(" ", ".") circuits = fetch_file(disease, key="circuits", version="latest", debug=debug) circuits.index = circuits.index.str.replace("-", ".").str.replace(" ", ".") + circuits[circuits_column] = circuits[circuits_column].astype(bool) genes = fetch_file(disease, key="genes", version="latest", debug=debug) # gene_exp = gene_exp[genes.index[genes[genes_column]]] From 07e92e2d3753144d1ae76042ac99c70e020613da Mon Sep 17 00:00:00 2001 From: Carlos Loucera Date: Fri, 5 May 2023 08:37:23 +0200 Subject: [PATCH 3/4] Move preprocessing to specific functions. --- drexml/datasets.py | 89 ++++++++++++++++++++++++++++++++-------------- 1 file changed, 63 insertions(+), 26 deletions(-) diff --git a/drexml/datasets.py b/drexml/datasets.py index 23340ee7..e8db1c18 100644 --- a/drexml/datasets.py +++ b/drexml/datasets.py @@ -29,7 +29,7 @@ RECORD_ID = "7737166" -def fetch_file(disease, key, version="latest", debug=False): +def fetch_file(disease, key, env, version="latest", debug=False): """Retrieve data.""" print(f"Retrieving {key}") experiment_env_path = pathlib.Path(disease) @@ -52,7 +52,10 @@ def fetch_file(disease, key, version="latest", debug=False): data_path = experiment_env_path.parent path = data_path.joinpath(env[key]) - return load_df(path, key) + frame = load_df(path, key) + frame = preprocess_frame(frame, env, key) + + return frame def load_df(path, key=None): @@ -76,8 +79,6 @@ def load_df(path, key=None): try: # tsv, and compressed tsv res = pd.read_csv(path, sep="\t") - if "index" in res.columns: - res = res.set_index("index", drop=True) except (ParserError, KeyError, UnicodeDecodeError) as err: print("Error found while trying to load a TSV or compressed TSV.") print(err) @@ -94,14 +95,6 @@ def load_df(path, key=None): if res.shape[0] == 0: raise NotImplementedError("Format not implemented yet.") - if key is not None: - index_name_options = get_index_name_options(key) - - for name in index_name_options: - if name in res.columns: - res = res.set_index(name, drop=True) - res.index = res.index.astype(str) - return res @@ -115,6 +108,49 @@ def get_index_name_options(key): return ["index"] +def preprocess_frame(res, env, key): + + if key is not None: + index_name_options = get_index_name_options(key) + + for name in index_name_options: + if name in res.columns: + res = res.set_index(name, drop=True) + res.index = res.index.astype(str) + + if key == "gene_exp": + return preprocess_gexp(res) + elif key == "pathvals": + return preprocess_activities(res) + elif key == "circuits": + return preprocess_map(res, env["circuits_column"]) + elif key == "genes": + return preprocess_genes(res, env["genes_column"]) + + +def preprocess_gexp(frame): + frame.columns = frame.columns.str.replace("X", "") + return frame + + +def preprocess_activities(frame): + frame.columns = frame.columns.str.replace("-", ".").str.replace(" ", ".") + return frame + + +def preprocess_map(frame, circuits_column): + frame.index = frame.index.str.replace("-", ".").str.replace(" ", ".") + frame[circuits_column] = frame[circuits_column].astype(bool) + + return frame + + +def preprocess_genes(frame, genes_column): + frame = frame.loc[frame[genes_column]] + + return frame + + def get_disease_data(disease, debug): """Get data for a disease. @@ -141,24 +177,25 @@ def get_disease_data(disease, debug): genes_column = env["genes_column"] circuits_column = env["circuits_column"] - gene_exp = fetch_file(disease, key="gene_exp", version="latest", debug=debug) - gene_exp.columns = gene_exp.columns.str.replace("X", "") - - pathvals = fetch_file(disease, key="pathvals", version="latest", debug=debug) - pathvals.columns = pathvals.columns.str.replace("-", ".").str.replace(" ", ".") - circuits = fetch_file(disease, key="circuits", version="latest", debug=debug) - circuits.index = circuits.index.str.replace("-", ".").str.replace(" ", ".") - circuits[circuits_column] = circuits[circuits_column].astype(bool) - genes = fetch_file(disease, key="genes", version="latest", debug=debug) + gene_exp = fetch_file( + disease, key="gene_exp", env=env, version="latest", debug=debug + ) + pathvals = fetch_file( + disease, key="pathvals", env=env, version="latest", debug=debug + ) + circuits = fetch_file( + disease, key="circuits", env=env, version="latest", debug=debug + ) + genes = fetch_file(disease, key="genes", env=env, version="latest", debug=debug) # gene_exp = gene_exp[genes.index[genes[genes_column]]] - getx_entrez = gene_exp.columns - this_genes = genes.index[genes[genes_column]] - if this_genes.difference(getx_entrez).size > 0: - print(f"# genes not present in GTEx: {this_genes.difference(getx_entrez).size}") + gtex_entrez = gene_exp.columns + gene_diff = genes.index.difference(gtex_entrez).size + if gene_diff > 0: + print(f"# genes not present in GTEx: {gene_diff}") - usable_genes = this_genes.intersection(getx_entrez) + usable_genes = genes.index.intersection(gtex_entrez) gene_exp = gene_exp[usable_genes] pathvals = pathvals[circuits.index[circuits[circuits_column]]] From e6db18fc92e316daf0eaefabda003531f2b7bf8a Mon Sep 17 00:00:00 2001 From: Carlos Loucera Date: Fri, 5 May 2023 12:24:45 +0200 Subject: [PATCH 4/4] Convert to absolute paths. --- drexml/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/drexml/utils.py b/drexml/utils.py index 7ef96b4c..f28eb5e8 100644 --- a/drexml/utils.py +++ b/drexml/utils.py @@ -147,7 +147,7 @@ def get_out_path(disease): The desired path. """ - env_possible = Path(disease) + env_possible = Path(disease).absolute() if env_possible.exists() and (env_possible.suffix == ".env"): print(f"Working with experiment {env_possible.parent.name}")