diff --git a/.gitignore b/.gitignore index ed1fc807..62f82a84 100644 --- a/.gitignore +++ b/.gitignore @@ -282,3 +282,4 @@ err2.txt out2.txt *.err *.out +results_July2021_gpu diff --git a/src/ml_plots.py b/src/ml_plots.py index 591e7381..a47e0216 100644 --- a/src/ml_plots.py +++ b/src/ml_plots.py @@ -127,6 +127,7 @@ def save_median_df(rel_cv, cut, symbol_dict, results_path): def plot_relevance_distribution(rel_cv, cut, symb_dict, pdir, extensions=exts): + sns.set_theme(context="paper", style="whitegrid") query_top = rel_cv.columns[rel_cv.median() > cut] to_plot = rel_cv.loc[:, query_top].copy() to_plot = to_plot.loc[:, to_plot.median().sort_values(ascending=False).index] @@ -136,10 +137,21 @@ def plot_relevance_distribution(rel_cv, cut, symb_dict, pdir, extensions=exts): n_genes = to_plot.shape[1] # sns.set_context("poster") - plt.figure() + fig, ax = plt.subplots(1, 1) + #def set_size(fig): + #fig.set_size_inches(6, 3) + #plt.tight_layout() figsize_x = 1.0 if n_genes < 90 else 1.2 - ax = to_plot.plot(kind="box", figsize=(16 * figsize_x, 9), rot=90) - ax.set_ylabel("Relevance") + to_plot = to_plot.melt() + sns.boxplot(x="variable", y="value", data=to_plot, color="lightgray", ax=ax) + #ax = to_plot.plot(kind="box", rot=90, color="lightgray") + ax.set_ylabel("Relevance", fontsize=12) + ax.set_xlabel("Gene", fontsize=12) + ax.set_xticklabels(ax.get_xticklabels(), rotation=90) + #set_size(fig) + sns.despine() + #g.fig.set_size_inches(6, 3) + fig.set_size_inches(6, 3) plt.tight_layout() name = "cv_relevance_distribution" @@ -151,7 +163,7 @@ def plot_relevance_distribution(rel_cv, cut, symb_dict, pdir, extensions=exts): for ext in extensions: fname = f"{name}.{ext}" fpath = pdir.joinpath(fname) - plt.savefig(fpath, dpi=300, bbox_inches="tight", pad_inches=0) + fig.savefig(fpath, dpi=300, bbox_inches="tight", pad_inches=0.05) plt.close()