diff --git a/drexml/datasets.py b/drexml/datasets.py index 23340ee7..e8db1c18 100644 --- a/drexml/datasets.py +++ b/drexml/datasets.py @@ -29,7 +29,7 @@ RECORD_ID = "7737166" -def fetch_file(disease, key, version="latest", debug=False): +def fetch_file(disease, key, env, version="latest", debug=False): """Retrieve data.""" print(f"Retrieving {key}") experiment_env_path = pathlib.Path(disease) @@ -52,7 +52,10 @@ def fetch_file(disease, key, version="latest", debug=False): data_path = experiment_env_path.parent path = data_path.joinpath(env[key]) - return load_df(path, key) + frame = load_df(path, key) + frame = preprocess_frame(frame, env, key) + + return frame def load_df(path, key=None): @@ -76,8 +79,6 @@ def load_df(path, key=None): try: # tsv, and compressed tsv res = pd.read_csv(path, sep="\t") - if "index" in res.columns: - res = res.set_index("index", drop=True) except (ParserError, KeyError, UnicodeDecodeError) as err: print("Error found while trying to load a TSV or compressed TSV.") print(err) @@ -94,14 +95,6 @@ def load_df(path, key=None): if res.shape[0] == 0: raise NotImplementedError("Format not implemented yet.") - if key is not None: - index_name_options = get_index_name_options(key) - - for name in index_name_options: - if name in res.columns: - res = res.set_index(name, drop=True) - res.index = res.index.astype(str) - return res @@ -115,6 +108,49 @@ def get_index_name_options(key): return ["index"] +def preprocess_frame(res, env, key): + + if key is not None: + index_name_options = get_index_name_options(key) + + for name in index_name_options: + if name in res.columns: + res = res.set_index(name, drop=True) + res.index = res.index.astype(str) + + if key == "gene_exp": + return preprocess_gexp(res) + elif key == "pathvals": + return preprocess_activities(res) + elif key == "circuits": + return preprocess_map(res, env["circuits_column"]) + elif key == "genes": + return preprocess_genes(res, env["genes_column"]) + + +def preprocess_gexp(frame): + frame.columns = frame.columns.str.replace("X", "") + return frame + + +def preprocess_activities(frame): + frame.columns = frame.columns.str.replace("-", ".").str.replace(" ", ".") + return frame + + +def preprocess_map(frame, circuits_column): + frame.index = frame.index.str.replace("-", ".").str.replace(" ", ".") + frame[circuits_column] = frame[circuits_column].astype(bool) + + return frame + + +def preprocess_genes(frame, genes_column): + frame = frame.loc[frame[genes_column]] + + return frame + + def get_disease_data(disease, debug): """Get data for a disease. @@ -141,24 +177,25 @@ def get_disease_data(disease, debug): genes_column = env["genes_column"] circuits_column = env["circuits_column"] - gene_exp = fetch_file(disease, key="gene_exp", version="latest", debug=debug) - gene_exp.columns = gene_exp.columns.str.replace("X", "") - - pathvals = fetch_file(disease, key="pathvals", version="latest", debug=debug) - pathvals.columns = pathvals.columns.str.replace("-", ".").str.replace(" ", ".") - circuits = fetch_file(disease, key="circuits", version="latest", debug=debug) - circuits.index = circuits.index.str.replace("-", ".").str.replace(" ", ".") - circuits[circuits_column] = circuits[circuits_column].astype(bool) - genes = fetch_file(disease, key="genes", version="latest", debug=debug) + gene_exp = fetch_file( + disease, key="gene_exp", env=env, version="latest", debug=debug + ) + pathvals = fetch_file( + disease, key="pathvals", env=env, version="latest", debug=debug + ) + circuits = fetch_file( + disease, key="circuits", env=env, version="latest", debug=debug + ) + genes = fetch_file(disease, key="genes", env=env, version="latest", debug=debug) # gene_exp = gene_exp[genes.index[genes[genes_column]]] - getx_entrez = gene_exp.columns - this_genes = genes.index[genes[genes_column]] - if this_genes.difference(getx_entrez).size > 0: - print(f"# genes not present in GTEx: {this_genes.difference(getx_entrez).size}") + gtex_entrez = gene_exp.columns + gene_diff = genes.index.difference(gtex_entrez).size + if gene_diff > 0: + print(f"# genes not present in GTEx: {gene_diff}") - usable_genes = this_genes.intersection(getx_entrez) + usable_genes = genes.index.intersection(gtex_entrez) gene_exp = gene_exp[usable_genes] pathvals = pathvals[circuits.index[circuits[circuits_column]]]