From e49c1eb4793e08e69480a63582fd90ce59f74144 Mon Sep 17 00:00:00 2001
From: Henry
Date: Tue, 28 Nov 2023 16:52:38 +0100
Subject: [PATCH] :art: allow custom display name of feat

---
 project/01_2_performance_plots.ipynb | 54 +++++++++++++++-------------
 project/01_2_performance_plots.py    | 34 ++++++++++--------
 vaep/plotting/errors.py              |  8 +++--
 3 files changed, 55 insertions(+), 41 deletions(-)

diff --git a/project/01_2_performance_plots.ipynb b/project/01_2_performance_plots.ipynb
index 78db92131..a4870a184 100644
--- a/project/01_2_performance_plots.ipynb
+++ b/project/01_2_performance_plots.ipynb
@@ -36,7 +36,6 @@
     "import matplotlib.pyplot as plt\n",
     "import numpy as np\n",
     "import pandas as pd\n",
-    "import seaborn as sns\n",
     "\n",
     "import vaep\n",
     "import vaep.imputation\n",
@@ -105,7 +104,6 @@
    "execution_count": null,
    "id": "e6e91c6b-20d6-402c-9577-a2bfd8ba592e",
    "metadata": {
-    "lines_to_next_cell": 2,
     "tags": [
      "parameters"
     ]
@@ -122,7 +120,8 @@
     "models: str = 'Median,CF,DAE,VAE'  # picked models to compare (comma separated)\n",
     "sel_models: str = ''  # user defined comparison (comma separated)\n",
     "# Restrict plotting to top N methods for imputation based on error of validation data, maximum 10\n",
-    "plot_to_n: int = 5"
+    "plot_to_n: int = 5\n",
+    "feat_name_display: str = None  # display name for feature name (e.g. 'protein group')"
    ]
   },
   {
@@ -187,6 +186,7 @@
     "MIN_FREQ = None\n",
     "MODELS_PASSED = args.models.split(',')\n",
     "MODELS = MODELS_PASSED.copy()\n",
+    "FEAT_NAME_DISPLAY = args.feat_name_display\n",
     "SEL_MODELS = None\n",
     "if args.sel_models:\n",
     "    SEL_MODELS = args.sel_models.split(',')"
@@ -430,6 +430,9 @@
     "    split='val',\n",
     "    model_keys=MODELS_PASSED,\n",
     "    shared_columns=[TARGET_COL])\n",
+    "SAMPLE_ID, FEAT_NAME = pred_val.index.names\n",
+    "if not FEAT_NAME_DISPLAY:\n",
+    "    FEAT_NAME_DISPLAY = FEAT_NAME\n",
     "pred_val[MODELS]"
    ]
   },
@@ -738,6 +741,7 @@
     "    ],\n",
     "    feat_medians=data.train_X.median(),\n",
     "    ax=ax,\n",
+    "    feat_name=FEAT_NAME_DISPLAY,\n",
     "    palette=TOP_N_COLOR_PALETTE,\n",
     "    metric_name=METRIC,)\n",
     "ax.set_ylabel(f\"Average error ({METRIC})\")\n",
@@ -784,7 +788,7 @@
     "    model_keys=MODELS_PASSED,\n",
     "    shared_columns=[TARGET_COL])\n",
     "pred_test = pred_test.join(freq_feat, on=freq_feat.index.name)\n",
-    "SAMPLE_ID, FEAT_NAME = pred_test.index.names\n",
+    "\n",
     "pred_test"
    ]
   },
@@ -1103,7 +1107,7 @@
    "source": [
     "kwargs = dict(rot=90,\n",
     "              flierprops=dict(markersize=1),\n",
-    "              ylabel=f'correlation per {FEAT_NAME}')\n",
+    "              ylabel=f'correlation per {FEAT_NAME_DISPLAY}')\n",
     "ax = (corr_per_feat_test\n",
     "      .loc[~too_few_obs, TOP_N_ORDER]\n",
     "      .plot\n",
@@ -1255,7 +1259,7 @@
     "fig, ax = plt.subplots(figsize=(4, 2))\n",
     "ax = _to_plot.loc[[feature_names.name]].plot.bar(\n",
     "    rot=0,\n",
-    "    ylabel=f\"{METRIC} for {feature_names.name} ({n_in_comparison:,} intensities)\",\n",
+    "    ylabel=f\"{METRIC} for {FEAT_NAME_DISPLAY} ({n_in_comparison:,} intensities)\",\n",
     "    # title=f'performance on test data (based on {n_in_comparison:,} measurements)',\n",
     "    color=COLORS_TO_USE,\n",
     "    ax=ax,\n",
@@ -1313,6 +1317,7 @@
     "    ],\n",
     "    feat_medians=data.train_X.median(),\n",
     "    ax=ax,\n",
+    "    feat_name=FEAT_NAME_DISPLAY,\n",
     "    metric_name=METRIC,\n",
     "    palette=COLORS_TO_USE\n",
     ")\n",
@@ -1326,6 +1331,23 @@
     "errors_binned"
    ]
   },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "b13ecd37",
+   "metadata": {
+    "tags": []
+   },
+   "outputs": [],
+   "source": [
+    "(errors_binned\n",
+    " .set_index(\n",
+    "     ['model', errors_binned.columns[-1]]\n",
+    " )\n",
+    " .loc[ORDER_MODELS[0]]\n",
+    " .sort_values(by=METRIC))"
+   ]
+  },
   {
    "cell_type": "markdown",
    "id": "26370a1a",
@@ -1372,7 +1394,7 @@
     "    fig, ax = plt.subplots(figsize=(4, 2))\n",
     "    ax = _to_plot.loc[[feature_names.name]].plot.bar(\n",
     "        rot=0,\n",
-    "        ylabel=f\"{METRIC} for {feature_names.name} ({n_in_comparison:,} intensities)\",\n",
+    "        ylabel=f\"{METRIC} for {FEAT_NAME_DISPLAY} ({n_in_comparison:,} intensities)\",\n",
     "        # title=f'performance on test data (based on {n_in_comparison:,} measurements)',\n",
     "        color=vaep.plotting.defaults.assign_colors(\n",
     "            list(k.upper() for k in SEL_MODELS)),\n",
@@ -1413,6 +1435,7 @@
     "        feat_medians=data.train_X.median(),\n",
     "        ax=ax,\n",
     "        metric_name=METRIC,\n",
+    "        feat_name=FEAT_NAME_DISPLAY,\n",
     "        palette=vaep.plotting.defaults.assign_colors(\n",
     "            list(k.upper() for k in SEL_MODELS))\n",
     "    )\n",
@@ -1429,23 +1452,6 @@
     "    display(errors_binned)"
    ]
   },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "id": "b13ecd37",
-   "metadata": {
-    "tags": []
-   },
-   "outputs": [],
-   "source": [
-    "(errors_binned\n",
-    " .set_index(\n",
-    "     ['model', errors_binned.columns[-1]]\n",
-    " )\n",
-    " .loc[ORDER_MODELS[0]]\n",
-    " .sort_values(by=METRIC))"
-   ]
-  },
   {
    "cell_type": "markdown",
    "id": "549236ca-9e89-47aa-905c-c97a45d4dc2b",
diff --git a/project/01_2_performance_plots.py b/project/01_2_performance_plots.py
index 0d1526a09..59e626918 100644
--- a/project/01_2_performance_plots.py
+++ b/project/01_2_performance_plots.py
@@ -36,7 +36,6 @@
 import matplotlib.pyplot as plt
 import numpy as np
 import pandas as pd
-import seaborn as sns
 
 import vaep
 import vaep.imputation
@@ -98,7 +97,7 @@ def build_text(s):
 sel_models: str = ''  # user defined comparison (comma separated)
 # Restrict plotting to top N methods for imputation based on error of validation data, maximum 10
 plot_to_n: int = 5
-
+feat_name_display: str = None  # display name for feature name (e.g. 'protein group')
 
 # %% [markdown]
 # Some argument transformations
@@ -121,6 +120,7 @@ def build_text(s):
 MIN_FREQ = None
 MODELS_PASSED = args.models.split(',')
 MODELS = MODELS_PASSED.copy()
+FEAT_NAME_DISPLAY = args.feat_name_display
 SEL_MODELS = None
 if args.sel_models:
     SEL_MODELS = args.sel_models.split(',')
@@ -227,6 +227,9 @@ def build_text(s):
     split='val',
     model_keys=MODELS_PASSED,
     shared_columns=[TARGET_COL])
+SAMPLE_ID, FEAT_NAME = pred_val.index.names
+if not FEAT_NAME_DISPLAY:
+    FEAT_NAME_DISPLAY = FEAT_NAME
 pred_val[MODELS]
 
 # %% [markdown]
@@ -370,6 +373,7 @@ def build_text(s):
     ],
     feat_medians=data.train_X.median(),
     ax=ax,
+    feat_name=FEAT_NAME_DISPLAY,
     palette=TOP_N_COLOR_PALETTE,
     metric_name=METRIC,)
 ax.set_ylabel(f"Average error ({METRIC})")
@@ -393,7 +397,7 @@ def build_text(s):
     model_keys=MODELS_PASSED,
     shared_columns=[TARGET_COL])
 pred_test = pred_test.join(freq_feat, on=freq_feat.index.name)
-SAMPLE_ID, FEAT_NAME = pred_test.index.names
+
 pred_test
 
 # %% [markdown]
@@ -553,7 +557,7 @@ def build_text(s):
 # %%
 kwargs = dict(rot=90,
               flierprops=dict(markersize=1),
-              ylabel=f'correlation per {FEAT_NAME}')
+              ylabel=f'correlation per {FEAT_NAME_DISPLAY}')
 ax = (corr_per_feat_test
       .loc[~too_few_obs, TOP_N_ORDER]
       .plot
@@ -635,7 +639,7 @@ def highlight_min(s, color, tolerence=0.00001):
 fig, ax = plt.subplots(figsize=(4, 2))
 ax = _to_plot.loc[[feature_names.name]].plot.bar(
     rot=0,
-    ylabel=f"{METRIC} for {feature_names.name} ({n_in_comparison:,} intensities)",
+    ylabel=f"{METRIC} for {FEAT_NAME_DISPLAY} ({n_in_comparison:,} intensities)",
     # title=f'performance on test data (based on {n_in_comparison:,} measurements)',
     color=COLORS_TO_USE,
     ax=ax,
@@ -670,6 +674,7 @@ def highlight_min(s, color, tolerence=0.00001):
     ],
     feat_medians=data.train_X.median(),
     ax=ax,
+    feat_name=FEAT_NAME_DISPLAY,
     metric_name=METRIC,
     palette=COLORS_TO_USE
 )
@@ -682,6 +687,14 @@ def highlight_min(s, color, tolerence=0.00001):
 errors_binned.to_csv(fname.with_suffix('.csv'))
 errors_binned
 
+# %%
+(errors_binned
+ .set_index(
+     ['model', errors_binned.columns[-1]]
+ )
+ .loc[ORDER_MODELS[0]]
+ .sort_values(by=METRIC))
+
 # %% [markdown]
 # ### Custom model selection
 
@@ -715,7 +728,7 @@ def highlight_min(s, color, tolerence=0.00001):
     fig, ax = plt.subplots(figsize=(4, 2))
     ax = _to_plot.loc[[feature_names.name]].plot.bar(
         rot=0,
-        ylabel=f"{METRIC} for {feature_names.name} ({n_in_comparison:,} intensities)",
+        ylabel=f"{METRIC} for {FEAT_NAME_DISPLAY} ({n_in_comparison:,} intensities)",
         # title=f'performance on test data (based on {n_in_comparison:,} measurements)',
         color=vaep.plotting.defaults.assign_colors(
             list(k.upper() for k in SEL_MODELS)),
@@ -750,6 +763,7 @@ def highlight_min(s, color, tolerence=0.00001):
         feat_medians=data.train_X.median(),
         ax=ax,
         metric_name=METRIC,
+        feat_name=FEAT_NAME_DISPLAY,
         palette=vaep.plotting.defaults.assign_colors(
             list(k.upper() for k in SEL_MODELS))
     )
@@ -765,14 +779,6 @@ def highlight_min(s, color, tolerence=0.00001):
 #     ax.xaxis.set_tick_params(rotation=0)  # horizontal
 display(errors_binned)
 
-# %%
-(errors_binned
- .set_index(
-     ['model', errors_binned.columns[-1]]
- )
- .loc[ORDER_MODELS[0]]
- .sort_values(by=METRIC))
-
 # %% [markdown]
 # ### Error by non-decimal number of intensity
 #
diff --git a/vaep/plotting/errors.py b/vaep/plotting/errors.py
index 7d6dd790f..5326b9d86 100644
--- a/vaep/plotting/errors.py
+++ b/vaep/plotting/errors.py
@@ -52,6 +52,7 @@ def plot_errors_by_median(pred: pd.DataFrame,
                           target_col='observed',
                           ax: Axes = None,
                           palette: dict = None,
+                          feat_name: str = None,
                           metric_name: Optional[str] = None,
                           errwidth: float = 1.2) -> tuple[Axes, pd.DataFrame]:
     # calculate absolute errors
@@ -74,9 +75,10 @@ def plot_errors_by_median(pred: pd.DataFrame,
 
     errors = errors.join(n_obs, on="bin")
 
-    feat_name = feat_medians.index.name
-    if not feat_name:
-        feat_name = 'feature'
+    if feat_name is None:
+        feat_name = feat_medians.index.name
+        if not feat_name:
+            feat_name = 'feature'
 
     x_axis_name = f'intensity binned by median of {feat_name}'
     len_max_bin = len(str(int(errors['bin'].max())))
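
Editor's note on the change: the new `feat_name_display` notebook parameter flows through `args.feat_name_display` into `FEAT_NAME_DISPLAY`, falls back to the feature index name taken from `pred_val.index.names`, and is finally passed as `feat_name=` to `vaep.plotting.errors.plot_errors_by_median`, where it ends up in the x-axis label. A minimal sketch of the resulting name-resolution precedence, assuming a toy `pd.Series` in place of `data.train_X.median()`; the helper `resolve_feat_name` and the example data are illustrative, not part of the patch:

```python
import pandas as pd


def resolve_feat_name(feat_medians: pd.Series, feat_name: str = None) -> str:
    """Mirror the precedence introduced in plot_errors_by_median (sketch)."""
    if feat_name is None:
        # no explicit display name passed: fall back to the index name
        feat_name = feat_medians.index.name
        if not feat_name:
            # unnamed index: last-resort generic label
            feat_name = 'feature'
    return feat_name


# hypothetical feature medians, indexed by protein group identifiers
medians = pd.Series([23.1, 25.4],
                    index=pd.Index(['P1', 'P2'], name='protein groups'))

assert resolve_feat_name(medians) == 'protein groups'  # previous behaviour kept
assert resolve_feat_name(medians, 'protein group') == 'protein group'  # new override
unnamed = pd.Series([23.1], index=pd.Index(['P1']))
assert resolve_feat_name(unnamed) == 'feature'  # default for an unnamed index
```

The resolved name is then interpolated into `f'intensity binned by median of {feat_name}'`. Since the notebook's parameter cell carries the "parameters" tag (see the hunk around the `plot_to_n` cell), the display name can presumably be set at execution time, e.g. with papermill via `-p feat_name_display "protein group"`; the exact runner is not stated in the patch.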