diff --git a/environment.yml b/environment.yml index 12041fc40..a335db7ba 100644 --- a/environment.yml +++ b/environment.yml @@ -17,7 +17,7 @@ dependencies: - matplotlib - python-kaleido - plotly - - seaborn + - seaborn<0.13 - pip # ML - pytorch=1 #=1.13.1=py3.8_cuda11.7_cudnn8_0 diff --git a/project/01_0_split_data.ipynb b/project/01_0_split_data.ipynb index e812b1960..35fde68b8 100644 --- a/project/01_0_split_data.ipynb +++ b/project/01_0_split_data.ipynb @@ -15,25 +15,24 @@ "metadata": {}, "outputs": [], "source": [ + "import logging\n", "from functools import partial\n", "from pathlib import Path\n", - "import logging\n", - "from typing import Union, List\n", + "from typing import List, Union\n", "\n", - "from IPython.display import display\n", + "import matplotlib.pyplot as plt\n", "import numpy as np\n", "import pandas as pd\n", - "import matplotlib.pyplot as plt\n", - "from sklearn.model_selection import train_test_split\n", "import plotly.express as px\n", + "from IPython.display import display\n", + "from sklearn.model_selection import train_test_split\n", "\n", "import vaep\n", + "import vaep.io.load\n", + "from vaep.analyzers import analyzers\n", "from vaep.io.datasplits import DataSplits\n", "from vaep.sampling import feature_frequency\n", - "\n", - "from vaep.analyzers import analyzers\n", "from vaep.sklearn import get_PCA\n", - "import vaep.io.load\n", "\n", "logger = vaep.logging.setup_nb_logger()\n", "logger.info(\"Split data and make diagnostic plots\")\n", @@ -52,7 +51,7 @@ "\n", "\n", "pd.options.display.max_columns = 32\n", - "plt.rcParams['figure.figsize'] = [3, 2]\n", + "plt.rcParams['figure.figsize'] = [4, 2]\n", "\n", "vaep.plotting.make_large_descriptors(7)\n", "\n", @@ -82,6 +81,7 @@ "cell_type": "code", "execution_count": null, "metadata": { + "lines_to_next_cell": 2, "tags": [ "parameters" ] @@ -108,7 +108,8 @@ "# train, validation and test data splits\n", "frac_non_train: float = 0.1 # fraction of non training data (validation and test split)\n", "frac_mnar: float = 0.0 # fraction of missing not at random data, rest: missing completely at random\n", - "prop_sample_w_sim: float = 1.0 # proportion of samples with simulated missing values" + "prop_sample_w_sim: float = 1.0 # proportion of samples with simulated missing values\n", + "feat_name_display: str = None # display name for feature name (e.g. 
'protein group')" ] }, { @@ -124,9 +125,7 @@ { "cell_type": "code", "execution_count": null, - "metadata": { - "lines_to_next_cell": 2 - }, + "metadata": {}, "outputs": [], "source": [ "args = vaep.nb.args_from_dict(args)\n", @@ -195,6 +194,11 @@ " )\n", "if args.column_names:\n", " df.columns.names = args.column_names\n", + "if args.feat_name_display is None:\n", + " args.overwrite_entry('feat_name_display', 'features')\n", + " if args.column_names:\n", + " args.overwrite_entry('feat_name_display', args.column_names[0])\n", + "\n", "\n", "if not df.index.name:\n", " logger.warning(\"No sample index name found, setting to 'Sample ID'\")\n", @@ -221,7 +225,7 @@ " .plot\n", " .box()\n", " )\n", - "ax.set_ylabel('number of observation across samples')" + "ax.set_ylabel('Frequency')" ] }, { @@ -557,7 +561,7 @@ "source": [ "group = 1\n", "ax = df.notna().sum(axis=1).hist()\n", - "ax.set_xlabel('features per eligable sample')\n", + "ax.set_xlabel(f'{args.feat_name_display.capitalize()} per eligable sample')\n", "ax.set_ylabel('observations')\n", "fname = args.out_figures / f'0_{group}_hist_features_per_sample'\n", "figures[fname.stem] = fname\n", @@ -576,7 +580,7 @@ "_new_labels = [l_.get_text().split(';')[0] for l_ in ax.get_xticklabels()]\n", "_ = ax.set_xticklabels(_new_labels, rotation=45,\n", " horizontalalignment='right')\n", - "ax.set_xlabel('feature prevalence')\n", + "ax.set_xlabel(f'{args.feat_name_display.capitalize()} prevalence')\n", "ax.set_ylabel('observations')\n", "fname = args.out_figures / f'0_{group}_feature_prevalence'\n", "figures[fname.stem] = fname\n", @@ -599,8 +603,9 @@ "min_max = vaep.plotting.data.min_max(df.stack())\n", "ax, bins = vaep.plotting.data.plot_histogram_intensities(\n", " df.stack(), min_max=min_max)\n", - "\n", + "ax.set_xlabel('Intensity binned')\n", "fname = args.out_figures / f'0_{group}_intensity_distribution_overall'\n", + "\n", "figures[fname.stem] = fname\n", "vaep.savefig(ax.get_figure(), fname)" ] @@ -614,6 +619,9 @@ "ax = vaep.plotting.data.plot_feat_median_over_prop_missing(\n", " data=df, type='scatter')\n", "fname = args.out_figures / f'0_{group}_intensity_median_vs_prop_missing_scatter'\n", + "ax.set_xlabel(\n", + " f'{args.feat_name_display.capitalize()} binned by their median intensity'\n", + " f' (N {args.feat_name_display})')\n", "figures[fname.stem] = fname\n", "vaep.savefig(ax.get_figure(), fname)" ] @@ -624,11 +632,17 @@ "metadata": {}, "outputs": [], "source": [ - "ax = vaep.plotting.data.plot_feat_median_over_prop_missing(\n", - " data=df, type='boxplot')\n", + "ax, _data_feat_median_over_prop_missing = vaep.plotting.data.plot_feat_median_over_prop_missing(\n", + " data=df, type='boxplot', return_plot_data=True)\n", "fname = args.out_figures / f'0_{group}_intensity_median_vs_prop_missing_boxplot'\n", + "ax.set_xlabel(\n", + " f'{args.feat_name_display.capitalize()} binned by their median intensity'\n", + " f' (N {args.feat_name_display})')\n", "figures[fname.stem] = fname\n", - "vaep.savefig(ax.get_figure(), fname)" + "vaep.savefig(ax.get_figure(), fname)\n", + "_data_feat_median_over_prop_missing.to_csv(fname.with_suffix('.csv'))\n", + "# _data_feat_median_over_prop_missing.to_excel(fname.with_suffix('.xlsx'))\n", + "del _data_feat_median_over_prop_missing" ] }, { @@ -644,7 +658,8 @@ "metadata": {}, "outputs": [], "source": [ - "sample_counts.name = 'identified features'" + "_feature_display_name = f'identified {args.feat_name_display}'\n", + "sample_counts.name = _feature_display_name" ] }, { @@ -718,12 +733,12 @@ "outputs": [], 
"source": [ "fig, ax = plt.subplots()\n", - "col_identified_feat = 'identified features'\n", + "col_identified_feat = _feature_display_name\n", "analyzers.plot_scatter(\n", " pcs[pcs_name],\n", " ax,\n", " pcs[col_identified_feat],\n", - " title=f'by {col_identified_feat}',\n", + " feat_name_display=args.feat_name_display,\n", " size=5,\n", ")\n", "fname = (args.out_figures\n", @@ -737,12 +752,23 @@ "execution_count": null, "metadata": {}, "outputs": [], + "source": [ + "# ! write principal components to excel (if needed)\n", + "# pcs.set_index([df.index.name])[[*pcs_name, col_identified_feat]].to_excel(fname.with_suffix('.xlsx'))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1c861197", + "metadata": {}, + "outputs": [], "source": [ "fig = px.scatter(\n", " pcs, x=pcs_name[0], y=pcs_name[1],\n", " hover_name=pcs_index_name,\n", " # hover_data=analysis.df_meta,\n", - " title=f'First two Principal Components of {args.M} features for {pcs.shape[0]} samples',\n", + " title=f'First two Principal Components of {args.M} {args.feat_name_display} for {pcs.shape[0]} samples',\n", " # color=pcs['Software Version'],\n", " color=col_identified_feat,\n", " template='none',\n", @@ -1295,7 +1321,8 @@ "ax.legend(_legend[:-1])\n", "if args.use_every_nth_xtick > 1:\n", " ax.set_xticks(ax.get_xticks()[::2])\n", - "fname = args.out_figures / f'0_{group}_test_over_train_split.pdf'\n", + "ax.set_xlabel('Intensity bins')\n", + "fname = args.out_figures / f'0_{group}_val_over_train_split.pdf'\n", "figures[fname.name] = fname\n", "vaep.savefig(ax.get_figure(), fname)" ] @@ -1324,6 +1351,18 @@ "vaep.savefig(ax.get_figure(), fname)" ] }, + { + "cell_type": "code", + "execution_count": null, + "id": "ed941a81", + "metadata": {}, + "outputs": [], + "source": [ + "counts_per_bin = vaep.pandas.get_counts_per_bin(df=splits_df, bins=bins)\n", + "counts_per_bin.to_excel(fname.with_suffix('.xlsx'))\n", + "counts_per_bin" + ] + }, { "cell_type": "code", "execution_count": null, @@ -1346,6 +1385,36 @@ "vaep.savefig(ax.get_figure(), fname)" ] }, + { + "cell_type": "code", + "execution_count": null, + "id": "eede54fd", + "metadata": {}, + "outputs": [], + "source": [ + "# Save binned counts" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1e6e1f6b", + "metadata": {}, + "outputs": [], + "source": [ + "counts_per_bin = dict()\n", + "for col in splits_df.columns:\n", + " _series = (pd.cut(splits_df[col], bins=bins)\n", + " .to_frame()\n", + " .groupby(col)\n", + " .size())\n", + " _series.index.name = 'bin'\n", + " counts_per_bin[col] = _series\n", + "counts_per_bin = pd.DataFrame(counts_per_bin)\n", + "counts_per_bin.to_excel(fname.with_suffix('.xlsx'))\n", + "counts_per_bin" + ] + }, { "cell_type": "markdown", "metadata": {}, @@ -1407,8 +1476,8 @@ "fig, ax = plt.subplots(figsize=(6, 2))\n", "s = 1\n", "s_axes = pd.DataFrame({'medians': medians,\n", - " 'validation split': splits.val_y.notna().sum(),\n", - " 'training split': splits.train_X.notna().sum()}\n", + " 'Validation split': splits.val_y.notna().sum(),\n", + " 'Training split': splits.train_X.notna().sum()}\n", " ).plot.box(by='medians',\n", " boxprops=dict(linewidth=s),\n", " flierprops=dict(markersize=s),\n", @@ -1417,7 +1486,9 @@ " _ = ax.set_xticklabels(ax.get_xticklabels(),\n", " rotation=45,\n", " horizontalalignment='right')\n", - "\n", + " ax.set_xlabel(f'{args.feat_name_display.capitalize()} binned by their median intensity '\n", + " f'(N {args.feat_name_display})')\n", + " _ = ax.set_ylabel('Frequency')\n", 
"fname = args.out_figures / f'0_{group}_intensity_median_vs_prop_missing_boxplot_val_train'\n", "figures[fname.stem] = fname\n", "vaep.savefig(ax.get_figure(), fname)" diff --git a/project/01_0_split_data.py b/project/01_0_split_data.py index f41601783..9b4061096 100644 --- a/project/01_0_split_data.py +++ b/project/01_0_split_data.py @@ -19,25 +19,24 @@ # Create data splits # %% +import logging from functools import partial from pathlib import Path -import logging -from typing import Union, List +from typing import List, Union -from IPython.display import display +import matplotlib.pyplot as plt import numpy as np import pandas as pd -import matplotlib.pyplot as plt -from sklearn.model_selection import train_test_split import plotly.express as px +from IPython.display import display +from sklearn.model_selection import train_test_split import vaep +import vaep.io.load +from vaep.analyzers import analyzers from vaep.io.datasplits import DataSplits from vaep.sampling import feature_frequency - -from vaep.analyzers import analyzers from vaep.sklearn import get_PCA -import vaep.io.load logger = vaep.logging.setup_nb_logger() logger.info("Split data and make diagnostic plots") @@ -56,7 +55,7 @@ def align_meta_data(df: pd.DataFrame, df_meta: pd.DataFrame): pd.options.display.max_columns = 32 -plt.rcParams['figure.figsize'] = [3, 2] +plt.rcParams['figure.figsize'] = [4, 2] vaep.plotting.make_large_descriptors(7) @@ -93,6 +92,8 @@ def align_meta_data(df: pd.DataFrame, df_meta: pd.DataFrame): frac_non_train: float = 0.1 # fraction of non training data (validation and test split) frac_mnar: float = 0.0 # fraction of missing not at random data, rest: missing completely at random prop_sample_w_sim: float = 1.0 # proportion of samples with simulated missing values +feat_name_display: str = None # display name for feature name (e.g. 'protein group') + # %% args = vaep.nb.get_params(args, globals=globals()) @@ -102,7 +103,6 @@ def align_meta_data(df: pd.DataFrame, df_meta: pd.DataFrame): args = vaep.nb.args_from_dict(args) args - # %% if not 0.0 <= args.frac_mnar <= 1.0: raise ValueError("Invalid MNAR float value (should be betw. 
0 and 1):" @@ -140,6 +140,11 @@ def align_meta_data(df: pd.DataFrame, df_meta: pd.DataFrame): ) if args.column_names: df.columns.names = args.column_names +if args.feat_name_display is None: + args.overwrite_entry('feat_name_display', 'features') + if args.column_names: + args.overwrite_entry('feat_name_display', args.column_names[0]) + if not df.index.name: logger.warning("No sample index name found, setting to 'Sample ID'") @@ -158,7 +163,7 @@ def align_meta_data(df: pd.DataFrame, df_meta: pd.DataFrame): .plot .box() ) -ax.set_ylabel('number of observation across samples') +ax.set_ylabel('Frequency') # %% @@ -355,7 +360,7 @@ def join_as_str(seq): # %% group = 1 ax = df.notna().sum(axis=1).hist() -ax.set_xlabel('features per eligable sample') +ax.set_xlabel(f'{args.feat_name_display.capitalize()} per eligable sample') ax.set_ylabel('observations') fname = args.out_figures / f'0_{group}_hist_features_per_sample' figures[fname.stem] = fname @@ -366,7 +371,7 @@ def join_as_str(seq): _new_labels = [l_.get_text().split(';')[0] for l_ in ax.get_xticklabels()] _ = ax.set_xticklabels(_new_labels, rotation=45, horizontalalignment='right') -ax.set_xlabel('feature prevalence') +ax.set_xlabel(f'{args.feat_name_display.capitalize()} prevalence') ax.set_ylabel('observations') fname = args.out_figures / f'0_{group}_feature_prevalence' figures[fname.stem] = fname @@ -380,8 +385,9 @@ def join_as_str(seq): min_max = vaep.plotting.data.min_max(df.stack()) ax, bins = vaep.plotting.data.plot_histogram_intensities( df.stack(), min_max=min_max) - +ax.set_xlabel('Intensity binned') fname = args.out_figures / f'0_{group}_intensity_distribution_overall' + figures[fname.stem] = fname vaep.savefig(ax.get_figure(), fname) @@ -389,21 +395,31 @@ def join_as_str(seq): ax = vaep.plotting.data.plot_feat_median_over_prop_missing( data=df, type='scatter') fname = args.out_figures / f'0_{group}_intensity_median_vs_prop_missing_scatter' +ax.set_xlabel( + f'{args.feat_name_display.capitalize()} binned by their median intensity' + f' (N {args.feat_name_display})') figures[fname.stem] = fname vaep.savefig(ax.get_figure(), fname) # %% -ax = vaep.plotting.data.plot_feat_median_over_prop_missing( - data=df, type='boxplot') +ax, _data_feat_median_over_prop_missing = vaep.plotting.data.plot_feat_median_over_prop_missing( + data=df, type='boxplot', return_plot_data=True) fname = args.out_figures / f'0_{group}_intensity_median_vs_prop_missing_boxplot' +ax.set_xlabel( + f'{args.feat_name_display.capitalize()} binned by their median intensity' + f' (N {args.feat_name_display})') figures[fname.stem] = fname vaep.savefig(ax.get_figure(), fname) +_data_feat_median_over_prop_missing.to_csv(fname.with_suffix('.csv')) +# _data_feat_median_over_prop_missing.to_excel(fname.with_suffix('.xlsx')) +del _data_feat_median_over_prop_missing # %% [markdown] # ### Interactive and Single plots # %% -sample_counts.name = 'identified features' +_feature_display_name = f'identified {args.feat_name_display}' +sample_counts.name = _feature_display_name # %% K = 2 @@ -443,12 +459,12 @@ def join_as_str(seq): # %% fig, ax = plt.subplots() -col_identified_feat = 'identified features' +col_identified_feat = _feature_display_name analyzers.plot_scatter( pcs[pcs_name], ax, pcs[col_identified_feat], - title=f'by {col_identified_feat}', + feat_name_display=args.feat_name_display, size=5, ) fname = (args.out_figures @@ -456,12 +472,16 @@ def join_as_str(seq): figures[fname.stem] = fname vaep.savefig(fig, fname) +# %% +# # ! 
write principal components to excel (if needed) +# pcs.set_index([df.index.name])[[*pcs_name, col_identified_feat]].to_excel(fname.with_suffix('.xlsx')) + # %% fig = px.scatter( pcs, x=pcs_name[0], y=pcs_name[1], hover_name=pcs_index_name, # hover_data=analysis.df_meta, - title=f'First two Principal Components of {args.M} features for {pcs.shape[0]} samples', + title=f'First two Principal Components of {args.M} {args.feat_name_display} for {pcs.shape[0]} samples', # color=pcs['Software Version'], color=col_identified_feat, template='none', @@ -814,7 +834,8 @@ def join_as_str(seq): ax.legend(_legend[:-1]) if args.use_every_nth_xtick > 1: ax.set_xticks(ax.get_xticks()[::2]) -fname = args.out_figures / f'0_{group}_test_over_train_split.pdf' +ax.set_xlabel('Intensity bins') +fname = args.out_figures / f'0_{group}_val_over_train_split.pdf' figures[fname.name] = fname vaep.savefig(ax.get_figure(), fname) @@ -836,6 +857,11 @@ def join_as_str(seq): figures[fname.name] = fname vaep.savefig(ax.get_figure(), fname) +# %% +counts_per_bin = vaep.pandas.get_counts_per_bin(df=splits_df, bins=bins) +counts_per_bin.to_excel(fname.with_suffix('.xlsx')) +counts_per_bin + # %% ax = splits_df.drop('train', axis=1).plot.hist(bins=bins, xticks=list(bins), @@ -852,6 +878,22 @@ def join_as_str(seq): figures[fname.name] = fname vaep.savefig(ax.get_figure(), fname) +# %% +# Save binned counts + +# %% +counts_per_bin = dict() +for col in splits_df.columns: + _series = (pd.cut(splits_df[col], bins=bins) + .to_frame() + .groupby(col) + .size()) + _series.index.name = 'bin' + counts_per_bin[col] = _series +counts_per_bin = pd.DataFrame(counts_per_bin) +counts_per_bin.to_excel(fname.with_suffix('.xlsx')) +counts_per_bin + # %% [markdown] # plot training data missing plots @@ -886,8 +928,8 @@ def join_as_str(seq): fig, ax = plt.subplots(figsize=(6, 2)) s = 1 s_axes = pd.DataFrame({'medians': medians, - 'validation split': splits.val_y.notna().sum(), - 'training split': splits.train_X.notna().sum()} + 'Validation split': splits.val_y.notna().sum(), + 'Training split': splits.train_X.notna().sum()} ).plot.box(by='medians', boxprops=dict(linewidth=s), flierprops=dict(markersize=s), @@ -896,7 +938,9 @@ def join_as_str(seq): _ = ax.set_xticklabels(ax.get_xticklabels(), rotation=45, horizontalalignment='right') - + ax.set_xlabel(f'{args.feat_name_display.capitalize()} binned by their median intensity ' + f'(N {args.feat_name_display})') + _ = ax.set_ylabel('Frequency') fname = args.out_figures / f'0_{group}_intensity_median_vs_prop_missing_boxplot_val_train' figures[fname.stem] = fname vaep.savefig(ax.get_figure(), fname) diff --git a/project/01_1_transfer_NAGuideR_pred.ipynb b/project/01_1_transfer_NAGuideR_pred.ipynb index 5e86f0fa5..a985a61c1 100644 --- a/project/01_1_transfer_NAGuideR_pred.ipynb +++ b/project/01_1_transfer_NAGuideR_pred.ipynb @@ -267,7 +267,7 @@ "id": "33fde68c", "metadata": {}, "source": [ - "### Test Datasplit" + "## Test Datasplit" ] }, { diff --git a/project/01_1_transfer_NAGuideR_pred.py b/project/01_1_transfer_NAGuideR_pred.py index 786db0b50..ed152c987 100644 --- a/project/01_1_transfer_NAGuideR_pred.py +++ b/project/01_1_transfer_NAGuideR_pred.py @@ -6,7 +6,7 @@ # extension: .py # format_name: percent # format_version: '1.3' -# jupytext_version: 1.15.2 +# jupytext_version: 1.15.0 # kernelspec: # display_name: Python 3 # language: python @@ -139,7 +139,7 @@ pd.DataFrame(added_metrics) # %% [markdown] -# ### Test Datasplit +# ## Test Datasplit # %% added_metrics = 
d_metrics.add_metrics(test_pred_fake_na.dropna(how='all', axis=1), 'test_fake_na') diff --git a/project/01_2_performance_plots.ipynb b/project/01_2_performance_plots.ipynb index 1eed838d0..263026456 100644 --- a/project/01_2_performance_plots.ipynb +++ b/project/01_2_performance_plots.ipynb @@ -28,28 +28,28 @@ "outputs": [], "source": [ "import logging\n", - "import yaml\n", "import random\n", "from pathlib import Path\n", "\n", - "from IPython.display import display\n", "import matplotlib.pyplot as plt\n", "import numpy as np\n", "import pandas as pd\n", + "import yaml\n", + "from IPython.display import display\n", "\n", "import vaep\n", "import vaep.imputation\n", "import vaep.models\n", - "from vaep.models.collect_dumps import collect, select_content\n", - "from vaep.io import datasplits\n", - "from vaep.analyzers import compare_predictions\n", "import vaep.nb\n", + "from vaep.analyzers import compare_predictions\n", + "from vaep.io import datasplits\n", + "from vaep.models.collect_dumps import collect, select_content\n", "\n", "pd.options.display.max_rows = 30\n", "pd.options.display.min_rows = 10\n", "pd.options.display.max_colwidth = 100\n", "\n", - "plt.rcParams.update({'figure.figsize': (3, 2)})\n", + "plt.rcParams.update({'figure.figsize': (4, 2)})\n", "vaep.plotting.make_large_descriptors(7)\n", "\n", "logger = vaep.logging.setup_nb_logger()\n", @@ -122,7 +122,7 @@ "sel_models: str = '' # user defined comparison (comma separated)\n", "# Restrict plotting to top N methods for imputation based on error of validation data, maximum 10\n", "plot_to_n: int = 5\n", - "feat_name_display: str = None # display name for feature name (e.g. 'protein group')\n", + "feat_name_display: str = None # display name for feature name in plural (e.g. 'protein groups')\n", "save_agg_pred: bool = False # save aggregated predictions of validation and test data" ] }, @@ -231,14 +231,17 @@ }, "outputs": [], "source": [ - "fig, axes = plt.subplots(1, 2, sharey=True)\n", + "fig, axes = plt.subplots(1, 2, sharey=True, sharex=True)\n", "\n", "vaep.plotting.data.plot_observations(data.val_y.unstack(), ax=axes[0],\n", - " title='Validation split', size=1)\n", + " title='Validation split', size=1, xlabel='')\n", "vaep.plotting.data.plot_observations(data.test_y.unstack(), ax=axes[1],\n", - " title='Test split', size=1)\n", - "\n", + " title='Test split', size=1, xlabel='')\n", "fig.suptitle(\"Simulated missing values per sample\", size=8)\n", + "# hide axis and use only for common x label\n", + "fig.add_subplot(111, frameon=False)\n", + "plt.tick_params(labelcolor='none', which='both', top=False, bottom=False, left=False, right=False)\n", + "plt.xlabel(f'Samples ordered by identified {data.val_y.index.names[-1]}')\n", "group = 1\n", "fname = args.out_figures / f'2_{group}_fake_na_val_test_splits.png'\n", "figures[fname.stem] = fname\n", @@ -278,7 +281,9 @@ "outputs": [], "source": [ "prop = freq_feat / len(data.train_X.index.levels[0])\n", - "prop.sort_values().to_frame().plot()" + "prop.sort_values().to_frame().plot(\n", + " xlabel=f'{data.val_y.index.names[-1]}',\n", + " ylabel='Proportion of identification in samples')" ] }, { @@ -551,6 +556,7 @@ "execution_count": null, "id": "a2440887-b5f2-45a1-90cd-d15ef9bfa0a7", "metadata": { + "lines_to_next_cell": 2, "tags": [] }, "outputs": [], @@ -561,39 +567,6 @@ "TOP_N_ORDER" ] }, - { - "cell_type": "markdown", - "id": "dd6818ca-460a-4f14-be85-c4309057e161", - "metadata": {}, - "source": [ - "### Correlation overall" - ] - }, - { - "cell_type": "code", - 
"execution_count": null, - "id": "3aa7831e-ebf3-4de4-af6c-c4b2a8b00373", - "metadata": { - "tags": [] - }, - "outputs": [], - "source": [ - "pred_val_corr = pred_val.corr()\n", - "ax = (pred_val_corr\n", - " .loc[TARGET_COL, ORDER_MODELS]\n", - " .plot\n", - " .bar(\n", - " # title='Correlation between Fake NA and model predictions on validation data',\n", - " ylabel='correlation overall'))\n", - "ax = vaep.plotting.add_height_to_barplot(ax)\n", - "ax.set_xticklabels(ax.get_xticklabels(), rotation=45,\n", - " horizontalalignment='right')\n", - "fname = args.out_figures / f'2_{group}_pred_corr_val_overall.pdf'\n", - "figures[fname.stem] = fname\n", - "vaep.savefig(ax.get_figure(), name=fname)\n", - "pred_val_corr" - ] - }, { "cell_type": "markdown", "id": "0ac5f058-c580-4676-83c8-768bdb30f526", @@ -756,11 +729,31 @@ " palette=TOP_N_COLOR_PALETTE,\n", " metric_name=METRIC,)\n", "ax.set_ylabel(f\"Average error ({METRIC})\")\n", + "ax.legend(loc='best', ncols=len(TOP_N_ORDER))\n", "fname = args.out_figures / f'2_{group}_errors_binned_by_feat_median_val.pdf'\n", "figures[fname.stem] = fname\n", "vaep.savefig(ax.get_figure(), name=fname)" ] }, + { + "cell_type": "code", + "execution_count": null, + "id": "6f6ffdd5", + "metadata": { + "lines_to_next_cell": 2 + }, + "outputs": [], + "source": [ + "# ! only used for reporting\n", + "plotted = vaep.plotting.errors.get_data_for_errors_by_median(\n", + " errors=errors_binned,\n", + " feat_name=FEAT_NAME_DISPLAY,\n", + " metric_name=METRIC\n", + ")\n", + "plotted.to_excel(fname.with_suffix('.xlsx'), index=False)\n", + "plotted" + ] + }, { "cell_type": "code", "execution_count": null, @@ -905,7 +898,7 @@ " COLORS_TO_USE[:top_n],\n", " axes):\n", "\n", - " ax, _ = vaep.plotting.data.plot_histogram_intensities(\n", + " ax, bins = vaep.plotting.data.plot_histogram_intensities(\n", " pred_test[TARGET_COL],\n", " color='grey',\n", " min_max=min_max,\n", @@ -929,36 +922,19 @@ "vaep.savefig(fig, name=fname)" ] }, - { - "cell_type": "markdown", - "id": "116a7b7e", - "metadata": {}, - "source": [ - "### Correlation overall" - ] - }, { "cell_type": "code", "execution_count": null, - "id": "b42efaec-4556-45e9-a813-66da159e771c", - "metadata": { - "tags": [] - }, + "id": "843a917f", + "metadata": {}, "outputs": [], "source": [ - "pred_test_corr = pred_test.corr()\n", - "ax = pred_test_corr.loc[TARGET_COL, ORDER_MODELS].plot.bar(\n", - " # title='Corr. 
between Fake NA and model predictions on test data',\n", - " ylabel='correlation coefficient overall',\n", - " ylim=(0.7, 1)\n", - ")\n", - "ax = vaep.plotting.add_height_to_barplot(ax)\n", - "ax.set_xticklabels(ax.get_xticklabels(), rotation=45,\n", - " horizontalalignment='right')\n", - "fname = args.out_figures / f'2_{group}_pred_corr_test_overall.pdf'\n", - "figures[fname.stem] = fname\n", - "vaep.savefig(ax.get_figure(), name=fname)\n", - "pred_test_corr" + "counts_per_bin = vaep.pandas.get_counts_per_bin(df=pred_test,\n", + " bins=bins,\n", + " columns=[TARGET_COL, *ORDER_MODELS[:top_n]])\n", + "\n", + "counts_per_bin.to_excel(fname.with_suffix('.xlsx'))\n", + "counts_per_bin" ] }, { @@ -1058,7 +1034,7 @@ "feature_names = pred_test.index.levels[-1]\n", "N_SAMPLES = pred_test.index\n", "M = len(feature_names)\n", - "pred_test.loc[pd.IndexSlice[:, feature_names[random.randint(0, M)]], :]" + "pred_test.loc[pd.IndexSlice[:, feature_names[random.randint(0, M - 1)]], :]" ] }, { @@ -1160,7 +1136,6 @@ "execution_count": null, "id": "9993d145-8b78-4769-838a-01721900a3c7", "metadata": { - "lines_to_next_cell": 0, "tags": [] }, "outputs": [], @@ -1272,15 +1247,14 @@ }, "outputs": [], "source": [ - "\n", - "fig, ax = plt.subplots(figsize=(6, 2))\n", + "fig, ax = plt.subplots(figsize=(4, 2)) # size of the plot can be adjusted\n", "ax = _to_plot.loc[[feature_names.name]].plot.bar(\n", " rot=0,\n", - " ylabel=f\"{METRIC} for {FEAT_NAME_DISPLAY} ({n_in_comparison:,} intensities)\",\n", + " ylabel=f\"{METRIC} for {FEAT_NAME_DISPLAY}\\n({n_in_comparison:,} intensities)\",\n", " # title=f'performance on test data (based on {n_in_comparison:,} measurements)',\n", " color=COLORS_TO_USE,\n", " ax=ax,\n", - " width=.8)\n", + " width=.7)\n", "ax = vaep.plotting.add_height_to_barplot(ax, size=7)\n", "ax = vaep.plotting.add_text_to_barplot(ax, _to_plot.loc[\"text\"], size=7)\n", "ax.set_xticklabels([])\n", @@ -1326,7 +1300,7 @@ "outputs": [], "source": [ "vaep.plotting.make_large_descriptors(7)\n", - "fig, ax = plt.subplots(figsize=(6, 2))\n", + "fig, ax = plt.subplots(figsize=(8, 2))\n", "\n", "ax, errors_binned = vaep.plotting.errors.plot_errors_by_median(\n", " pred=pred_test[\n", @@ -1338,6 +1312,7 @@ " metric_name=METRIC,\n", " palette=COLORS_TO_USE\n", ")\n", + "ax.legend(loc='best', ncols=len(TOP_N_ORDER))\n", "vaep.plotting.make_large_descriptors(6)\n", "fname = args.out_figures / f'2_{group}_test_errors_binned_by_feat_medians.pdf'\n", "figures[fname.stem] = fname\n", @@ -1348,6 +1323,25 @@ "errors_binned" ] }, + { + "cell_type": "code", + "execution_count": null, + "id": "e1455bcc", + "metadata": { + "lines_to_next_cell": 2 + }, + "outputs": [], + "source": [ + "# ! 
only used for reporting\n", + "plotted = vaep.plotting.errors.get_data_for_errors_by_median(\n", + " errors=errors_binned,\n", + " feat_name=FEAT_NAME_DISPLAY,\n", + " metric_name=METRIC\n", + ")\n", + "plotted.to_excel(fname.with_suffix('.xlsx'), index=False)\n", + "plotted" + ] + }, { "cell_type": "code", "execution_count": null, @@ -1416,10 +1410,12 @@ " color=vaep.plotting.defaults.assign_colors(\n", " list(k.upper() for k in SEL_MODELS)),\n", " ax=ax,\n", - " width=.8)\n", + " width=.7)\n", + " ax.legend(loc='best', ncols=len(SEL_MODELS))\n", " ax = vaep.plotting.add_height_to_barplot(ax, size=5)\n", " ax = vaep.plotting.add_text_to_barplot(ax, _to_plot.loc[\"text\"], size=5)\n", " ax.set_xticklabels([])\n", + "\n", " fname = args.out_figures / f'2_{group}_performance_test_sel.pdf'\n", " figures[fname.stem] = fname\n", " vaep.savefig(fig, name=fname)\n", @@ -1437,7 +1433,9 @@ "cell_type": "code", "execution_count": null, "id": "2a578570", - "metadata": {}, + "metadata": { + "lines_to_next_cell": 2 + }, "outputs": [], "source": [ "# custom selection\n", @@ -1457,16 +1455,27 @@ " list(k.upper() for k in SEL_MODELS))\n", " )\n", " # ax.set_ylim(0, 1.5)\n", + " ax.legend(loc='best', ncols=len(SEL_MODELS))\n", " # for text in ax.legend().get_texts():\n", " # text.set_fontsize(6)\n", " fname = args.out_figures / f'2_{group}_test_errors_binned_by_feat_medians_sel.pdf'\n", " figures[fname.stem] = fname\n", " vaep.savefig(ax.get_figure(), name=fname)\n", + " plt.show(fig)\n", + "\n", " dumps[fname.stem] = fname.with_suffix('.csv')\n", " errors_binned.to_csv(fname.with_suffix('.csv'))\n", " vaep.plotting.make_large_descriptors(6)\n", " # ax.xaxis.set_tick_params(rotation=0) # horizontal\n", - " display(errors_binned)" + "\n", + " # ! only used for reporting\n", + " plotted = vaep.plotting.errors.get_data_for_errors_by_median(\n", + " errors=errors_binned,\n", + " feat_name=FEAT_NAME_DISPLAY,\n", + " metric_name=METRIC\n", + " )\n", + " plotted.to_excel(fname.with_suffix('.xlsx'), index=False)\n", + " display(plotted)" ] }, { @@ -1497,6 +1506,7 @@ " palette=TOP_N_COLOR_PALETTE,\n", " metric_name=METRIC,\n", ")\n", + "ax.legend(loc='best', ncols=len(TOP_N_ORDER))\n", "fname = args.out_figures / f'2_{group}_test_errors_binned_by_int.pdf'\n", "figures[fname.stem] = fname\n", "vaep.savefig(ax.get_figure(), name=fname)" diff --git a/project/01_2_performance_plots.py b/project/01_2_performance_plots.py index 0922adb72..80f3f7c02 100644 --- a/project/01_2_performance_plots.py +++ b/project/01_2_performance_plots.py @@ -28,28 +28,28 @@ # %% import logging -import yaml import random from pathlib import Path -from IPython.display import display import matplotlib.pyplot as plt import numpy as np import pandas as pd +import yaml +from IPython.display import display import vaep import vaep.imputation import vaep.models -from vaep.models.collect_dumps import collect, select_content -from vaep.io import datasplits -from vaep.analyzers import compare_predictions import vaep.nb +from vaep.analyzers import compare_predictions +from vaep.io import datasplits +from vaep.models.collect_dumps import collect, select_content pd.options.display.max_rows = 30 pd.options.display.min_rows = 10 pd.options.display.max_colwidth = 100 -plt.rcParams.update({'figure.figsize': (3, 2)}) +plt.rcParams.update({'figure.figsize': (4, 2)}) vaep.plotting.make_large_descriptors(7) logger = vaep.logging.setup_nb_logger() @@ -97,7 +97,7 @@ def build_text(s): sel_models: str = '' # user defined comparison (comma separated) # Restrict 
plotting to top N methods for imputation based on error of validation data, maximum 10 plot_to_n: int = 5 -feat_name_display: str = None # display name for feature name (e.g. 'protein group') +feat_name_display: str = None # display name for feature name in plural (e.g. 'protein groups') save_agg_pred: bool = False # save aggregated predictions of validation and test data @@ -139,14 +139,17 @@ def build_text(s): args.data, file_format=args.file_format) # %% -fig, axes = plt.subplots(1, 2, sharey=True) +fig, axes = plt.subplots(1, 2, sharey=True, sharex=True) vaep.plotting.data.plot_observations(data.val_y.unstack(), ax=axes[0], - title='Validation split', size=1) + title='Validation split', size=1, xlabel='') vaep.plotting.data.plot_observations(data.test_y.unstack(), ax=axes[1], - title='Test split', size=1) - + title='Test split', size=1, xlabel='') fig.suptitle("Simulated missing values per sample", size=8) +# hide axis and use only for common x label +fig.add_subplot(111, frameon=False) +plt.tick_params(labelcolor='none', which='both', top=False, bottom=False, left=False, right=False) +plt.xlabel(f'Samples ordered by identified {data.val_y.index.names[-1]}') group = 1 fname = args.out_figures / f'2_{group}_fake_na_val_test_splits.png' figures[fname.stem] = fname @@ -163,7 +166,9 @@ def build_text(s): # %% prop = freq_feat / len(data.train_X.index.levels[0]) -prop.sort_values().to_frame().plot() +prop.sort_values().to_frame().plot( + xlabel=f'{data.val_y.index.names[-1]}', + ylabel='Proportion of identification in samples') # %% [markdown] # View training data in wide format @@ -287,24 +292,6 @@ def build_text(s): color in zip(TOP_N_ORDER, COLORS_TO_USE)} TOP_N_ORDER -# %% [markdown] -# ### Correlation overall - -# %% -pred_val_corr = pred_val.corr() -ax = (pred_val_corr - .loc[TARGET_COL, ORDER_MODELS] - .plot - .bar( - # title='Correlation between Fake NA and model predictions on validation data', - ylabel='correlation overall')) -ax = vaep.plotting.add_height_to_barplot(ax) -ax.set_xticklabels(ax.get_xticklabels(), rotation=45, - horizontalalignment='right') -fname = args.out_figures / f'2_{group}_pred_corr_val_overall.pdf' -figures[fname.stem] = fname -vaep.savefig(ax.get_figure(), name=fname) -pred_val_corr # %% [markdown] # ### Correlation per sample @@ -386,10 +373,22 @@ def build_text(s): palette=TOP_N_COLOR_PALETTE, metric_name=METRIC,) ax.set_ylabel(f"Average error ({METRIC})") +ax.legend(loc='best', ncols=len(TOP_N_ORDER)) fname = args.out_figures / f'2_{group}_errors_binned_by_feat_median_val.pdf' figures[fname.stem] = fname vaep.savefig(ax.get_figure(), name=fname) +# %% +# # ! only used for reporting +plotted = vaep.plotting.errors.get_data_for_errors_by_median( + errors=errors_binned, + feat_name=FEAT_NAME_DISPLAY, + metric_name=METRIC +) +plotted.to_excel(fname.with_suffix('.xlsx'), index=False) +plotted + + # %% errors_binned.head() dumps[fname.stem] = fname.with_suffix('.csv') @@ -454,7 +453,7 @@ def build_text(s): COLORS_TO_USE[:top_n], axes): - ax, _ = vaep.plotting.data.plot_histogram_intensities( + ax, bins = vaep.plotting.data.plot_histogram_intensities( pred_test[TARGET_COL], color='grey', min_max=min_max, @@ -477,23 +476,13 @@ def build_text(s): figures[fname.stem] = fname vaep.savefig(fig, name=fname) -# %% [markdown] -# ### Correlation overall - # %% -pred_test_corr = pred_test.corr() -ax = pred_test_corr.loc[TARGET_COL, ORDER_MODELS].plot.bar( - # title='Corr. 
between Fake NA and model predictions on test data', - ylabel='correlation coefficient overall', - ylim=(0.7, 1) -) -ax = vaep.plotting.add_height_to_barplot(ax) -ax.set_xticklabels(ax.get_xticklabels(), rotation=45, - horizontalalignment='right') -fname = args.out_figures / f'2_{group}_pred_corr_test_overall.pdf' -figures[fname.stem] = fname -vaep.savefig(ax.get_figure(), name=fname) -pred_test_corr +counts_per_bin = vaep.pandas.get_counts_per_bin(df=pred_test, + bins=bins, + columns=[TARGET_COL, *ORDER_MODELS[:top_n]]) + +counts_per_bin.to_excel(fname.with_suffix('.xlsx')) +counts_per_bin # %% [markdown] # ### Correlation per sample @@ -547,7 +536,7 @@ def build_text(s): feature_names = pred_test.index.levels[-1] N_SAMPLES = pred_test.index M = len(feature_names) -pred_test.loc[pd.IndexSlice[:, feature_names[random.randint(0, M)]], :] +pred_test.loc[pd.IndexSlice[:, feature_names[random.randint(0, M - 1)]], :] # %% options = random.sample(set(feature_names), 1) @@ -616,6 +605,7 @@ def highlight_min(s, color, tolerence=0.00001): ) else: print("None found") + # %% [markdown] # ### Error plot @@ -650,15 +640,14 @@ def highlight_min(s, color, tolerence=0.00001): # %% - -fig, ax = plt.subplots(figsize=(6, 2)) +fig, ax = plt.subplots(figsize=(4, 2)) # size of the plot can be adjusted ax = _to_plot.loc[[feature_names.name]].plot.bar( rot=0, - ylabel=f"{METRIC} for {FEAT_NAME_DISPLAY} ({n_in_comparison:,} intensities)", + ylabel=f"{METRIC} for {FEAT_NAME_DISPLAY}\n({n_in_comparison:,} intensities)", # title=f'performance on test data (based on {n_in_comparison:,} measurements)', color=COLORS_TO_USE, ax=ax, - width=.8) + width=.7) ax = vaep.plotting.add_height_to_barplot(ax, size=7) ax = vaep.plotting.add_text_to_barplot(ax, _to_plot.loc["text"], size=7) ax.set_xticklabels([]) @@ -681,7 +670,7 @@ def highlight_min(s, color, tolerence=0.00001): # %% vaep.plotting.make_large_descriptors(7) -fig, ax = plt.subplots(figsize=(6, 2)) +fig, ax = plt.subplots(figsize=(8, 2)) ax, errors_binned = vaep.plotting.errors.plot_errors_by_median( pred=pred_test[ @@ -693,6 +682,7 @@ def highlight_min(s, color, tolerence=0.00001): metric_name=METRIC, palette=COLORS_TO_USE ) +ax.legend(loc='best', ncols=len(TOP_N_ORDER)) vaep.plotting.make_large_descriptors(6) fname = args.out_figures / f'2_{group}_test_errors_binned_by_feat_medians.pdf' figures[fname.stem] = fname @@ -702,6 +692,17 @@ def highlight_min(s, color, tolerence=0.00001): errors_binned.to_csv(fname.with_suffix('.csv')) errors_binned +# %% +# # ! 
only used for reporting +plotted = vaep.plotting.errors.get_data_for_errors_by_median( + errors=errors_binned, + feat_name=FEAT_NAME_DISPLAY, + metric_name=METRIC +) +plotted.to_excel(fname.with_suffix('.xlsx'), index=False) +plotted + + # %% (errors_binned .set_index( @@ -748,10 +749,12 @@ def highlight_min(s, color, tolerence=0.00001): color=vaep.plotting.defaults.assign_colors( list(k.upper() for k in SEL_MODELS)), ax=ax, - width=.8) + width=.7) + ax.legend(loc='best', ncols=len(SEL_MODELS)) ax = vaep.plotting.add_height_to_barplot(ax, size=5) ax = vaep.plotting.add_text_to_barplot(ax, _to_plot.loc["text"], size=5) ax.set_xticklabels([]) + fname = args.out_figures / f'2_{group}_performance_test_sel.pdf' figures[fname.stem] = fname vaep.savefig(fig, name=fname) @@ -783,16 +786,28 @@ def highlight_min(s, color, tolerence=0.00001): list(k.upper() for k in SEL_MODELS)) ) # ax.set_ylim(0, 1.5) + ax.legend(loc='best', ncols=len(SEL_MODELS)) # for text in ax.legend().get_texts(): # text.set_fontsize(6) fname = args.out_figures / f'2_{group}_test_errors_binned_by_feat_medians_sel.pdf' figures[fname.stem] = fname vaep.savefig(ax.get_figure(), name=fname) + plt.show(fig) + dumps[fname.stem] = fname.with_suffix('.csv') errors_binned.to_csv(fname.with_suffix('.csv')) vaep.plotting.make_large_descriptors(6) # ax.xaxis.set_tick_params(rotation=0) # horizontal - display(errors_binned) + + # # ! only used for reporting + plotted = vaep.plotting.errors.get_data_for_errors_by_median( + errors=errors_binned, + feat_name=FEAT_NAME_DISPLAY, + metric_name=METRIC + ) + plotted.to_excel(fname.with_suffix('.xlsx'), index=False) + display(plotted) + # %% [markdown] # ### Error by non-decimal number of intensity @@ -809,6 +824,7 @@ def highlight_min(s, color, tolerence=0.00001): palette=TOP_N_COLOR_PALETTE, metric_name=METRIC, ) +ax.legend(loc='best', ncols=len(TOP_N_ORDER)) fname = args.out_figures / f'2_{group}_test_errors_binned_by_int.pdf' figures[fname.stem] = fname vaep.savefig(ax.get_figure(), name=fname) diff --git a/project/03_1_best_models_comparison.ipynb b/project/03_1_best_models_comparison.ipynb index 6a6f6f6fe..af3948834 100644 --- a/project/03_1_best_models_comparison.ipynb +++ b/project/03_1_best_models_comparison.ipynb @@ -6,18 +6,18 @@ "metadata": {}, "outputs": [], "source": [ + "import logging\n", "from pathlib import Path\n", - "import pandas as pd\n", "\n", "import matplotlib.pyplot as plt\n", + "import pandas as pd\n", "import seaborn as sns\n", "\n", - "import vaep.pandas\n", "import vaep.nb\n", - "\n", - "import logging\n", + "import vaep.pandas\n", "import vaep.plotting\n", "from vaep.logging import setup_logger\n", + "\n", "logger = setup_logger(logger=logging.getLogger('vaep'), level=10)\n", "\n", "plt.rcParams['figure.figsize'] = [4.0, 2.0]\n", @@ -164,16 +164,34 @@ "metadata": {}, "outputs": [], "source": [ - "ax = sns.barplot(x='data level',\n", - " y='MAE',\n", - " hue='model',\n", - " order=IDX[0],\n", - " palette=vaep.plotting.defaults.color_model_mapping,\n", - " ci=95,\n", - " errwidth=1.5,\n", - " data=view_long)\n", - "ax.set_xlabel('')\n", - "fig = ax.get_figure()" + "# individual points overlaid on bar plot:\n", + "# seaborn 12.2\n", + "# https://stackoverflow.com/a/69398767/9684872\n", + "sns.set_theme(context='paper', ) # font_scale=.8)\n", + "sns.set_style(\"whitegrid\")\n", + "g = sns.catplot(x=\"data level\", y=\"MAE\", hue='model', data=view_long,\n", + " kind=\"bar\",\n", + " errorbar=\"ci\", # ! 
95% confidence interval bootstrapped (using 1000 draws by default)\n", + " edgecolor=\"black\",\n", + " errcolor=\"black\",\n", + " hue_order=IDX[1],\n", + " order=IDX[0],\n", + " palette=vaep.plotting.defaults.color_model_mapping,\n", + " alpha=0.9,\n", + " height=2, # set the height of the figure\n", + " aspect=1.8 # set the aspect ratio of the figure\n", + " )\n", + "\n", + "# map data to stripplot\n", + "g.map(sns.stripplot, 'data level', 'MAE', 'model',\n", + " hue_order=IDX[1], order=IDX[0],\n", + " palette=vaep.plotting.defaults.color_model_mapping,\n", + " dodge=True, alpha=1, ec='k', linewidth=1,\n", + " s=2)\n", + "\n", + "fig = g.figure\n", + "ax = fig.get_axes()[0]\n", + "_ = ax.set_xlabel('')" ] }, { @@ -182,7 +200,7 @@ "metadata": {}, "outputs": [], "source": [ - "vaep.savefig(fig, FOLDER / \"model_performance_repeated_runs.pdf\")" + "vaep.savefig(fig, FOLDER / \"model_performance_repeated_runs.pdf\", tight_layout=False)" ] }, { diff --git a/project/03_1_best_models_comparison.py b/project/03_1_best_models_comparison.py index 97c54d8b5..6a2daadce 100644 --- a/project/03_1_best_models_comparison.py +++ b/project/03_1_best_models_comparison.py @@ -14,18 +14,18 @@ # --- # %% +import logging from pathlib import Path -import pandas as pd import matplotlib.pyplot as plt +import pandas as pd import seaborn as sns -import vaep.pandas import vaep.nb - -import logging +import vaep.pandas import vaep.plotting from vaep.logging import setup_logger + logger = setup_logger(logger=logging.getLogger('vaep'), level=10) plt.rcParams['figure.figsize'] = [4.0, 2.0] @@ -102,19 +102,37 @@ view_long # %% -ax = sns.barplot(x='data level', - y='MAE', - hue='model', - order=IDX[0], - palette=vaep.plotting.defaults.color_model_mapping, - ci=95, - errwidth=1.5, - data=view_long) -ax.set_xlabel('') -fig = ax.get_figure() +# individual points overlaid on bar plot: +# seaborn 12.2 +# https://stackoverflow.com/a/69398767/9684872 +sns.set_theme(context='paper', ) # font_scale=.8) +sns.set_style("whitegrid") +g = sns.catplot(x="data level", y="MAE", hue='model', data=view_long, + kind="bar", + errorbar="ci", # ! 
95% confidence interval bootstrapped (using 1000 draws by default) + edgecolor="black", + errcolor="black", + hue_order=IDX[1], + order=IDX[0], + palette=vaep.plotting.defaults.color_model_mapping, + alpha=0.9, + height=2, # set the height of the figure + aspect=1.8 # set the aspect ratio of the figure + ) + +# map data to stripplot +g.map(sns.stripplot, 'data level', 'MAE', 'model', + hue_order=IDX[1], order=IDX[0], + palette=vaep.plotting.defaults.color_model_mapping, + dodge=True, alpha=1, ec='k', linewidth=1, + s=2) + +fig = g.figure +ax = fig.get_axes()[0] +_ = ax.set_xlabel('') # %% -vaep.savefig(fig, FOLDER / "model_performance_repeated_runs.pdf") +vaep.savefig(fig, FOLDER / "model_performance_repeated_runs.pdf", tight_layout=False) # %% writer.close() diff --git a/project/10_7_ald_reduced_dataset_plots.ipynb b/project/10_7_ald_reduced_dataset_plots.ipynb index b27ea2db4..bdd4a796a 100644 --- a/project/10_7_ald_reduced_dataset_plots.ipynb +++ b/project/10_7_ald_reduced_dataset_plots.ipynb @@ -38,6 +38,10 @@ "COLORS_TO_USE_MAPPTING = vaep.plotting.defaults.color_model_mapping\n", "COLORS_TO_USE_MAPPTING[NONE_COL_NAME] = COLORS_TO_USE_MAPPTING['None']\n", "\n", + "COLORS_CONTIGENCY_TABLE = {\n", + " k: f'C{i}' for i, k in enumerate(['FP', 'TN', 'TP', 'FN'])\n", + "}\n", + "\n", "\n", "def plot_qvalues(df, x: str, y: list, ax=None, cutoff=0.05,\n", " alpha=1.0, style='.', markersize=3):\n", @@ -219,7 +223,7 @@ ")\n", "sel = qvalues_sel.loc[mask_lost_sign.squeeze()]\n", "sel.columns = sel.columns.droplevel(-1)\n", - "sel = sel[ORDER_MODELS + [REF_MODEL]]\n", + "sel = sel[ORDER_MODELS + [REF_MODEL]].sort_values(REF_MODEL)\n", "sel.to_excel(writer, sheet_name='lost_signal_qvalues')\n", "sel" ] @@ -242,9 +246,12 @@ ").droplevel(-1, axis=1)\n", ")\n", "da_target_sel_counts = njab.pandas.combine_value_counts(da_target_sel_counts)\n", - "ax = da_target_sel_counts.T.plot.bar(ylabel='count')\n", + "ax = da_target_sel_counts.T.plot.bar(ylabel='count',\n", + " color=[COLORS_CONTIGENCY_TABLE['FN'],\n", + " COLORS_CONTIGENCY_TABLE['TP']])\n", "ax.locator_params(axis='y', integer=True)\n", "fname = out_folder / 'lost_signal_da_counts.pdf'\n", + "da_target_sel_counts.fillna(0).to_excel(writer, sheet_name=fname.stem)\n", "files_out[fname.name] = fname.as_posix()\n", "vaep.savefig(ax.figure, fname)" ] @@ -292,7 +299,7 @@ ")\n", "sel = qvalues_sel.loc[mask_gained_signal.squeeze()]\n", "sel.columns = sel.columns.droplevel(-1)\n", - "sel = sel[ORDER_MODELS + [REF_MODEL]]\n", + "sel = sel[ORDER_MODELS + [REF_MODEL]].sort_values(REF_MODEL)\n", "sel.to_excel(writer, sheet_name='gained_signal_qvalues')\n", "sel" ] @@ -313,9 +320,12 @@ ").droplevel(-1, axis=1)\n", ")\n", "da_target_sel_counts = njab.pandas.combine_value_counts(da_target_sel_counts)\n", - "ax = da_target_sel_counts.T.plot.bar(ylabel='count')\n", + "ax = da_target_sel_counts.T.plot.bar(ylabel='count',\n", + " color=[COLORS_CONTIGENCY_TABLE['TN'],\n", + " COLORS_CONTIGENCY_TABLE['FP']])\n", "ax.locator_params(axis='y', integer=True)\n", "fname = out_folder / 'gained_signal_da_counts.pdf'\n", + "da_target_sel_counts.fillna(0).to_excel(writer, sheet_name=fname.stem)\n", "files_out[fname.name] = fname.as_posix()\n", "vaep.savefig(ax.figure, fname)" ] diff --git a/project/10_7_ald_reduced_dataset_plots.py b/project/10_7_ald_reduced_dataset_plots.py index 1169daa3b..30c2750f2 100644 --- a/project/10_7_ald_reduced_dataset_plots.py +++ b/project/10_7_ald_reduced_dataset_plots.py @@ -25,6 +25,10 @@ COLORS_TO_USE_MAPPTING = 
vaep.plotting.defaults.color_model_mapping COLORS_TO_USE_MAPPTING[NONE_COL_NAME] = COLORS_TO_USE_MAPPTING['None'] +COLORS_CONTIGENCY_TABLE = { + k: f'C{i}' for i, k in enumerate(['FP', 'TN', 'TP', 'FN']) +} + def plot_qvalues(df, x: str, y: list, ax=None, cutoff=0.05, alpha=1.0, style='.', markersize=3): @@ -111,7 +115,7 @@ def plot_qvalues(df, x: str, y: list, ax=None, cutoff=0.05, ) sel = qvalues_sel.loc[mask_lost_sign.squeeze()] sel.columns = sel.columns.droplevel(-1) -sel = sel[ORDER_MODELS + [REF_MODEL]] +sel = sel[ORDER_MODELS + [REF_MODEL]].sort_values(REF_MODEL) sel.to_excel(writer, sheet_name='lost_signal_qvalues') sel @@ -127,9 +131,12 @@ def plot_qvalues(df, x: str, y: list, ax=None, cutoff=0.05, ).droplevel(-1, axis=1) ) da_target_sel_counts = njab.pandas.combine_value_counts(da_target_sel_counts) -ax = da_target_sel_counts.T.plot.bar(ylabel='count') +ax = da_target_sel_counts.T.plot.bar(ylabel='count', + color=[COLORS_CONTIGENCY_TABLE['FN'], + COLORS_CONTIGENCY_TABLE['TP']]) ax.locator_params(axis='y', integer=True) fname = out_folder / 'lost_signal_da_counts.pdf' +da_target_sel_counts.fillna(0).to_excel(writer, sheet_name=fname.stem) files_out[fname.name] = fname.as_posix() vaep.savefig(ax.figure, fname) @@ -157,7 +164,7 @@ def plot_qvalues(df, x: str, y: list, ax=None, cutoff=0.05, ) sel = qvalues_sel.loc[mask_gained_signal.squeeze()] sel.columns = sel.columns.droplevel(-1) -sel = sel[ORDER_MODELS + [REF_MODEL]] +sel = sel[ORDER_MODELS + [REF_MODEL]].sort_values(REF_MODEL) sel.to_excel(writer, sheet_name='gained_signal_qvalues') sel @@ -171,9 +178,12 @@ def plot_qvalues(df, x: str, y: list, ax=None, cutoff=0.05, ).droplevel(-1, axis=1) ) da_target_sel_counts = njab.pandas.combine_value_counts(da_target_sel_counts) -ax = da_target_sel_counts.T.plot.bar(ylabel='count') +ax = da_target_sel_counts.T.plot.bar(ylabel='count', + color=[COLORS_CONTIGENCY_TABLE['TN'], + COLORS_CONTIGENCY_TABLE['FP']]) ax.locator_params(axis='y', integer=True) fname = out_folder / 'gained_signal_da_counts.pdf' +da_target_sel_counts.fillna(0).to_excel(writer, sheet_name=fname.stem) files_out[fname.name] = fname.as_posix() vaep.savefig(ax.figure, fname) diff --git a/project/misc_illustrations.ipynb b/project/misc_illustrations.ipynb index 02012ad42..f7365c2fd 100644 --- a/project/misc_illustrations.ipynb +++ b/project/misc_illustrations.ipynb @@ -14,6 +14,7 @@ "outputs": [], "source": [ "from pathlib import Path\n", + "\n", "import matplotlib.pyplot as plt\n", "import numpy as np\n", "import scipy.stats" @@ -58,15 +59,13 @@ { "cell_type": "code", "execution_count": null, - "metadata": { - "lines_to_next_cell": 0 - }, + "metadata": {}, "outputs": [], "source": [ "mu = 25.0\n", "stddev = 1.0\n", "\n", - "x = np.linspace(mu - 5, mu + 5, num=101)\n", + "x = np.linspace(mu - 3, mu + 3, num=101)\n", "\n", "y_normal = scipy.stats.norm.pdf(x, loc=mu, scale=stddev)\n", "\n", @@ -82,17 +81,19 @@ "for y, c in zip([y_normal, y_impute], colors):\n", " ax.plot(x, y, color=c,)\n", " ax.fill_between(x, y, color=c)\n", - " ax.set_xlabel('log2 intensity')\n", - " ax.set_ylabel('density')\n", - " ax.set_label(\"test\")\n", - " ax.legend([\"original\", \"down shifted\"])\n", + "ax.set_xlabel('log2 intensity')\n", + "ax.set_ylabel('density')\n", + "ax.set_label(\"test\")\n", + "ax.legend([\"original\", \"down shifted\"])\n", "fig.tight_layout()" ] }, { "cell_type": "code", "execution_count": null, - "metadata": {}, + "metadata": { + "lines_to_next_cell": 2 + }, "outputs": [], "source": [ "\n", @@ -101,6 +102,34 @@ 
"fig.savefig(FIGUREFOLDER / 'illustration_normal_imputation_highres', dpi=600)" ] }, + { + "cell_type": "code", + "execution_count": null, + "id": "3646ec23", + "metadata": {}, + "outputs": [], + "source": [ + "plt.rcParams.update({'xtick.labelsize': 'large',\n", + " 'ytick.labelsize': 'large',\n", + " 'axes.titlesize': 'large',\n", + " 'axes.labelsize': 'large',\n", + " })\n", + "fig, ax = plt.subplots(1, 1, figsize=(3, 2))\n", + "\n", + "for y, c in zip([y_normal], colors):\n", + " ax.plot(x, y, color=c,)\n", + " # ax.fill_between(x, y, color=c)\n", + " ax.set_xlabel('log2 intensity')\n", + " ax.set_ylabel('density')\n", + " ax.set_label(\"test\")\n", + " # ax.legend([\"original\", \"down shifted\"])\n", + "fig.tight_layout()\n", + "\n", + "fig.savefig(FIGUREFOLDER / 'illustration_normal')\n", + "fig.savefig(FIGUREFOLDER / 'illustration_normal.pdf')\n", + "fig.savefig(FIGUREFOLDER / 'illustration_normal_highres', dpi=600)" + ] + }, { "cell_type": "markdown", "metadata": {}, @@ -187,6 +216,78 @@ { "cell_type": "markdown", "metadata": {}, + "source": [ + "## Volcano plot" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c7883109", + "metadata": { + "lines_to_next_cell": 2 + }, + "outputs": [], + "source": [ + "# Sample data for the volcano plot\n", + "np.random.seed(42)\n", + "fold_change = np.random.default_rng().normal(0, 1, 1000)\n", + "p_value = np.random.default_rng().uniform(0, 1, 1000)\n", + "\n", + "# Volcano plot\n", + "# Assuming you have two arrays, fold_change and p_value, containing the fold change values and p-values respectively\n", + "\n", + "# Set the significance threshold for p-value\n", + "significance_threshold = 0.05\n", + "\n", + "# Set the fold change threshold\n", + "fold_change_threshold = 2\n", + "\n", + "# Create a boolean mask for significant points\n", + "significant_mask = (p_value < significance_threshold)\n", + "\n", + "# Create a boolean mask for points that meet the fold change threshold\n", + "fold_change_mask = (abs(fold_change) > fold_change_threshold)\n", + "\n", + "# Combine the masks to get the final mask for significant points\n", + "final_mask = significant_mask & fold_change_mask\n", + "\n", + "fig, ax = plt.subplots(1, 1, figsize=(3, 3))\n", + "# Plot the volcano plot\n", + "_ = ax.scatter(fold_change, -np.log10(p_value), c='gray', alpha=0.5, s=10)\n", + "_ = ax.scatter(fold_change[final_mask], -np.log10(p_value[final_mask]), c='red', alpha=0.7, s=20)\n", + "\n", + "# Add labels and title\n", + "_ = ax.set_xlabel('Log2 fold change')\n", + "_ = ax.set_ylabel('-log10(p-value)')\n", + "_ = ax.set_title('Volcano Plot')\n", + "\n", + "# Add significance threshold lines\n", + "_ = ax.axhline(-np.log10(significance_threshold), color='black', linestyle='--')\n", + "_ = ax.axvline(fold_change_threshold, color='black', linestyle='--')\n", + "_ = ax.axvline(-fold_change_threshold, color='black', linestyle='--')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b65b4e01", + "metadata": { + "lines_to_next_cell": 0 + }, + "outputs": [], + "source": [ + "fig.tight_layout()\n", + "fig.savefig(FIGUREFOLDER / 'illustration_volcano.png', dpi=300)\n", + "fig.savefig(FIGUREFOLDER / 'illustration_volcano.pdf') " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c324f9d8", + "metadata": {}, + "outputs": [], "source": [] } ], diff --git a/project/misc_illustrations.py b/project/misc_illustrations.py index cdfd362fb..2e5c50b9f 100644 --- a/project/misc_illustrations.py +++ 
b/project/misc_illustrations.py @@ -5,7 +5,7 @@ # extension: .py # format_name: percent # format_version: '1.3' -# jupytext_version: 1.15.2 +# jupytext_version: 1.15.0 # kernelspec: # display_name: Python 3 # language: python @@ -17,6 +17,7 @@ # %% from pathlib import Path + import matplotlib.pyplot as plt import numpy as np import scipy.stats @@ -44,7 +45,7 @@ mu = 25.0 stddev = 1.0 -x = np.linspace(mu - 5, mu + 5, num=101) +x = np.linspace(mu - 3, mu + 3, num=101) y_normal = scipy.stats.norm.pdf(x, loc=mu, scale=stddev) @@ -60,11 +61,12 @@ for y, c in zip([y_normal, y_impute], colors): ax.plot(x, y, color=c,) ax.fill_between(x, y, color=c) - ax.set_xlabel('log2 intensity') - ax.set_ylabel('density') - ax.set_label("test") - ax.legend(["original", "down shifted"]) +ax.set_xlabel('log2 intensity') +ax.set_ylabel('density') +ax.set_label("test") +ax.legend(["original", "down shifted"]) fig.tight_layout() + # %% fig.savefig(FIGUREFOLDER / 'illustration_normal_imputation') @@ -72,6 +74,28 @@ fig.savefig(FIGUREFOLDER / 'illustration_normal_imputation_highres', dpi=600) +# %% +plt.rcParams.update({'xtick.labelsize': 'large', + 'ytick.labelsize': 'large', + 'axes.titlesize': 'large', + 'axes.labelsize': 'large', + }) +fig, ax = plt.subplots(1, 1, figsize=(3, 2)) + +for y, c in zip([y_normal], colors): + ax.plot(x, y, color=c,) + # ax.fill_between(x, y, color=c) + ax.set_xlabel('log2 intensity') + ax.set_ylabel('density') + ax.set_label("test") + # ax.legend(["original", "down shifted"]) +fig.tight_layout() + +fig.savefig(FIGUREFOLDER / 'illustration_normal') +fig.savefig(FIGUREFOLDER / 'illustration_normal.pdf') +fig.savefig(FIGUREFOLDER / 'illustration_normal_highres', dpi=600) + + # %% [markdown] # ## Log transformations and errors # @@ -127,4 +151,50 @@ def rel_error(measurment, log_error, other_measurment): # whereas the error in the original space is the same # %% [markdown] -# +# ## Volcano plot + +# %% +# Sample data for the volcano plot +np.random.seed(42) +fold_change = np.random.default_rng().normal(0, 1, 1000) +p_value = np.random.default_rng().uniform(0, 1, 1000) + +# Volcano plot +# Assuming you have two arrays, fold_change and p_value, containing the fold change values and p-values respectively + +# Set the significance threshold for p-value +significance_threshold = 0.05 + +# Set the fold change threshold +fold_change_threshold = 2 + +# Create a boolean mask for significant points +significant_mask = (p_value < significance_threshold) + +# Create a boolean mask for points that meet the fold change threshold +fold_change_mask = (abs(fold_change) > fold_change_threshold) + +# Combine the masks to get the final mask for significant points +final_mask = significant_mask & fold_change_mask + +fig, ax = plt.subplots(1, 1, figsize=(3, 3)) +# Plot the volcano plot +_ = ax.scatter(fold_change, -np.log10(p_value), c='gray', alpha=0.5, s=10) +_ = ax.scatter(fold_change[final_mask], -np.log10(p_value[final_mask]), c='red', alpha=0.7, s=20) + +# Add labels and title +_ = ax.set_xlabel('Log2 fold change') +_ = ax.set_ylabel('-log10(p-value)') +_ = ax.set_title('Volcano Plot') + +# Add significance threshold lines +_ = ax.axhline(-np.log10(significance_threshold), color='black', linestyle='--') +_ = ax.axvline(fold_change_threshold, color='black', linestyle='--') +_ = ax.axvline(-fold_change_threshold, color='black', linestyle='--') + + +# %% +fig.tight_layout() +fig.savefig(FIGUREFOLDER / 'illustration_volcano.png', dpi=300) +fig.savefig(FIGUREFOLDER / 'illustration_volcano.pdf') +# %% diff 
--git a/setup.cfg b/setup.cfg index b987792a2..eb5b835d9 100644 --- a/setup.cfg +++ b/setup.cfg @@ -27,7 +27,7 @@ install_requires = torch scikit-learn>=1.0 scipy - seaborn + seaborn<0.13 fastai omegaconf tqdm diff --git a/tests/test_helpers.py b/tests/test_helpers.py index fcde11887..3e199940d 100644 --- a/tests/test_helpers.py +++ b/tests/test_helpers.py @@ -1,6 +1,5 @@ import numpy as np - from vaep.utils import create_random_missing_data @@ -8,4 +7,4 @@ def test_create_random_missing_data(): data = create_random_missing_data(N=43, M=13, prop_missing=0.2) assert data.shape == (43, 13) assert np.isnan(data).sum() - assert abs((float(np.isnan(data).sum()) / (43 * 13)) - 0.2) < 0.05 + assert abs((float(np.isnan(data).sum()) / (43 * 13)) - 0.2) < 0.1 diff --git a/vaep/analyzers/analyzers.py b/vaep/analyzers/analyzers.py index 999d83321..c807c58af 100644 --- a/vaep/analyzers/analyzers.py +++ b/vaep/analyzers/analyzers.py @@ -1,30 +1,23 @@ +import logging +import random from collections import namedtuple from pathlib import Path from types import SimpleNamespace -from typing import Tuple, Union, List - -import logging -import random - +from typing import List, Optional, Tuple, Union +import matplotlib.dates as mdates +import matplotlib.pyplot as plt import numpy as np import pandas as pd - -import matplotlib.pyplot as plt -import matplotlib.dates as mdates import seaborn - -from sklearn.decomposition import PCA -from sklearn.impute import SimpleImputer - from njab.sklearn import run_pca +from sklearn.impute import SimpleImputer import vaep from vaep.analyzers import Analysis - -from vaep.pandas import _add_indices from vaep.io.datasplits import long_format, wide_format from vaep.io.load import verify_df +from vaep.pandas import _add_indices logger = logging.getLogger(__name__) @@ -450,18 +443,26 @@ def plot_date_map(df, ax, def plot_scatter(df, ax, meta: pd.Series, - title: str = 'by some metadata', + feat_name_display: str = 'features', + title: Optional[str] = None, alpha=ALPHA, fontsize=8, size=2): cols = list(df.columns) assert len(cols) == 2, f'Please provide two dimensons, not {df.columns}' + if not title: + title = f'by identified {feat_name_display}' ax.set_title(title, fontsize=fontsize) ax.set_xlabel(cols[0]) ax.set_ylabel(cols[1]) path_collection = ax.scatter( x=cols[0], y=cols[1], s=size, c=meta, data=df, alpha=alpha) - cbar = ax.get_figure().colorbar(path_collection, ax=ax) + _ = ax.get_figure().colorbar(path_collection, ax=ax, + label=f'Identified {feat_name_display}', + # ticklocation='left', # ignored by matplotlib + location='right', # ! 
+                                 location='right',  # ! left does not put colorbar without overlapping y ticks
+                                 format="{x:,.0f}",
+                                 )


 def seaborn_scatter(df, ax,
diff --git a/vaep/pandas/__init__.py b/vaep/pandas/__init__.py
index e433e6763..ffaa60b17 100644
--- a/vaep/pandas/__init__.py
+++ b/vaep/pandas/__init__.py
@@ -1,7 +1,7 @@
 import collections.abc
 from collections import namedtuple
 from types import SimpleNamespace
-from typing import Iterable
+from typing import Iterable, List, Optional

 import numpy as np
 import omegaconf
@@ -283,3 +283,19 @@ def get_lower_whiskers(df: pd.DataFrame, factor: float = 1.5) -> pd.Series:
     iqr = ret.loc['75%'] - ret.loc['25%']
     ret = ret.loc['25%'] - iqr * factor
     return ret
+
+
+def get_counts_per_bin(df: pd.DataFrame, bins: range, columns: Optional[List[str]] = None):
+    """Return counts per bin for selected columns in DataFrame."""
+    counts_per_bin = dict()
+    if columns is None:
+        columns = df.columns.to_list()
+    for col in columns:
+        _series = (pd.cut(df[col], bins=bins)
+                   .to_frame()
+                   .groupby(col)
+                   .size())
+        _series.index.name = 'bin'
+        counts_per_bin[col] = _series
+    counts_per_bin = pd.DataFrame(counts_per_bin)
+    return counts_per_bin
diff --git a/vaep/plotting/__init__.py b/vaep/plotting/__init__.py
index 38e39c8d4..5053a7d66 100644
--- a/vaep/plotting/__init__.py
+++ b/vaep/plotting/__init__.py
@@ -45,14 +45,16 @@
 def _savefig(fig, name,
              folder: pathlib.Path = '.',
              pdf=True,
-             dpi=300  # default 'figure'
+             dpi=300,  # default 'figure',
+             tight_layout=True,
              ):
     """Save matplotlib Figure (having method `savefig`) as pdf and png."""
     folder = pathlib.Path(folder)
     fname = folder / name
     folder = fname.parent  # in case name specifies folders
     folder.mkdir(exist_ok=True, parents=True)
-    fig.tight_layout()
+    if tight_layout:
+        fig.tight_layout()
     fig.savefig(fname.with_suffix('.png'), dpi=dpi)
     if pdf:
         fig.savefig(fname.with_suffix('.pdf'), dpi=dpi)
diff --git a/vaep/plotting/data.py b/vaep/plotting/data.py
index ecc9ffd00..edc284529 100644
--- a/vaep/plotting/data.py
+++ b/vaep/plotting/data.py
@@ -1,13 +1,12 @@
 """Plot data distribution based on pandas `DataFrames` or `Series`."""
 import logging
-from typing import Tuple, Iterable
+from typing import Iterable, Tuple, Union

 import matplotlib
 import matplotlib.pyplot as plt
-from matplotlib.axes import Axes
 import pandas as pd
 import seaborn as sns
-
+from matplotlib.axes import Axes

 logger = logging.getLogger(__name__)
@@ -266,12 +265,15 @@ def plot_missing_pattern_histogram(data: pd.DataFrame,


 def plot_feat_median_over_prop_missing(data: pd.DataFrame,
                                        type: str = 'scatter',
-                                       ax=None,
-                                       s=1) -> matplotlib.axes.Axes:
+                                       ax: matplotlib.axes.Axes = None,
+                                       s: int = 1,
+                                       return_plot_data: bool = False
+                                       ) -> Union[matplotlib.axes.Axes,
+                                                  Tuple[matplotlib.axes.Axes, pd.DataFrame]]:
     """Plot feature median over proportion missing in that feature. Sorted by feature median into bins."""
     y_col = 'prop. missing'
-    x_col = 'Feature median intensity binned (based on N feature medians)'
+    x_col = 'Features binned by their median intensity (N features)'

     missing_by_median = {
         'median feat value': data.median(),
@@ -305,8 +307,8 @@ def plot_feat_median_over_prop_missing(data: pd.DataFrame,
         # # for some reason this does not work as it does elswhere:
         # _ = ax.set_xticklabels(ax.get_xticklabels(), rotation=45)
         # # do it manually:
-        _ = [(l.set_rotation(45), l.set_horizontalalignment('right'))
-             for l in ax.get_xticklabels()]
+        _ = [(_l.set_rotation(45), _l.set_horizontalalignment('right'))
+             for _l in ax.get_xticklabels()]
     elif type == 'boxplot':
         ax = missing_by_median[[x_col, y_col]].plot.box(
             by=x_col,
@@ -324,4 +326,6 @@ def plot_feat_median_over_prop_missing(data: pd.DataFrame,
     else:
         raise ValueError(
             f'Unknown plot type: {type}, choose from: scatter, boxplot')
+    if return_plot_data:
+        return ax, missing_by_median
     return ax
diff --git a/vaep/plotting/errors.py b/vaep/plotting/errors.py
index 5326b9d86..cdfeed140 100644
--- a/vaep/plotting/errors.py
+++ b/vaep/plotting/errors.py
@@ -1,10 +1,15 @@
 """Plot errors based on DataFrame with model predictions."""
 from __future__ import annotations

-import pandas as pd
+import itertools
 from typing import Optional
-from matplotlib.axes import Axes
+
+import matplotlib.pyplot as plt
+import numpy as np
+import pandas as pd
 import seaborn as sns
+from matplotlib.axes import Axes
+from seaborn.categorical import _BarPlotter

 import vaep.pandas.calc_errors
@@ -104,6 +109,34 @@ def plot_errors_by_median(pred: pd.DataFrame,
     return ax, errors


+def get_data_for_errors_by_median(errors: pd.DataFrame, feat_name, metric_name):
+    """Extract bars with confidence intervals from a seaborn bar plot.
+    Confidence intervals are calculated with bootstrapping (sampling the mean).
+
+    Relies on an internal seaborn class; only used for reporting source data in the paper.
+    """
+    x_axis_name = f'intensity binned by median of {feat_name}'
+
+    plotter = _BarPlotter(data=errors, x=x_axis_name, y=metric_name, hue='model',
+                          order=None, hue_order=None,
+                          estimator="mean", errorbar=("ci", 95), n_boot=1000, units=None, seed=None,
+                          orient=None, color=None, palette=None, saturation=.75, width=.8,
+                          errcolor=".26", errwidth=None, capsize=None, dodge=True)
+    ax = plt.gca()
+    plotter.plot(ax, {})
+    plt.close(ax.get_figure())
+    mean, cf_interval = plotter.statistic.flatten(), plotter.confint.reshape(-1, 2)
+    plotted = pd.DataFrame(np.concatenate((mean.reshape(-1, 1), cf_interval), axis=1), columns=[
+        'mean', 'ci_low', 'ci_high'])
+    _index = pd.DataFrame(list(itertools.product(
+        (_l.get_text() for _l in ax.get_xticklabels()),  # bins on x-axis
+        (_l.get_text() for _l in ax.get_legend().get_texts()),  # models in legend
+    )
+    ), columns=['bin', 'model'])
+    plotted = pd.concat([_index, plotted], axis=1)
+    return plotted
+
+
 def plot_rolling_error(errors: pd.DataFrame, metric_name: str, window: int = 200,
                        min_freq=None, freq_col: str = 'freq', colors_to_use=None, ax=None):