Skip to content

Commit

Permalink
🎨 allow custom display name of feat
Browse files: browse the repository at this point in the history
  • Loading branch information
Henry committed Nov 28, 2023
1 parent 49ee8e5 commit e49c1eb
Show file tree
Hide file tree
Showing 3 changed files with 55 additions and 41 deletions.
54 changes: 30 additions & 24 deletions project/01_2_performance_plots.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"import pandas as pd\n",
"import seaborn as sns\n",
"\n",
"import vaep\n",
"import vaep.imputation\n",
Expand Down Expand Up @@ -105,7 +104,6 @@
"execution_count": null,
"id": "e6e91c6b-20d6-402c-9577-a2bfd8ba592e",
"metadata": {
"lines_to_next_cell": 2,
"tags": [
"parameters"
]
Expand All @@ -122,7 +120,8 @@
"models: str = 'Median,CF,DAE,VAE' # picked models to compare (comma separated)\n",
"sel_models: str = '' # user defined comparison (comma separated)\n",
"# Restrict plotting to top N methods for imputation based on error of validation data, maximum 10\n",
"plot_to_n: int = 5"
"plot_to_n: int = 5\n",
"feat_name_display: str = None # display name for feature name (e.g. 'protein group')"
]
},
{
Expand Down Expand Up @@ -187,6 +186,7 @@
"MIN_FREQ = None\n",
"MODELS_PASSED = args.models.split(',')\n",
"MODELS = MODELS_PASSED.copy()\n",
"FEAT_NAME_DISPLAY = args.feat_name_display\n",
"SEL_MODELS = None\n",
"if args.sel_models:\n",
" SEL_MODELS = args.sel_models.split(',')"
Expand Down Expand Up @@ -430,6 +430,9 @@
" split='val',\n",
" model_keys=MODELS_PASSED,\n",
" shared_columns=[TARGET_COL])\n",
"SAMPLE_ID, FEAT_NAME = pred_val.index.names\n",
"if not FEAT_NAME_DISPLAY:\n",
" FEAT_NAME_DISPLAY = FEAT_NAME\n",
"pred_val[MODELS]"
]
},
Expand Down Expand Up @@ -738,6 +741,7 @@
" ],\n",
" feat_medians=data.train_X.median(),\n",
" ax=ax,\n",
" feat_name=FEAT_NAME_DISPLAY,\n",
" palette=TOP_N_COLOR_PALETTE,\n",
" metric_name=METRIC,)\n",
"ax.set_ylabel(f\"Average error ({METRIC})\")\n",
Expand Down Expand Up @@ -784,7 +788,7 @@
" model_keys=MODELS_PASSED,\n",
" shared_columns=[TARGET_COL])\n",
"pred_test = pred_test.join(freq_feat, on=freq_feat.index.name)\n",
"SAMPLE_ID, FEAT_NAME = pred_test.index.names\n",
"\n",
"pred_test"
]
},
Expand Down Expand Up @@ -1103,7 +1107,7 @@
"source": [
"kwargs = dict(rot=90,\n",
" flierprops=dict(markersize=1),\n",
" ylabel=f'correlation per {FEAT_NAME}')\n",
" ylabel=f'correlation per {FEAT_NAME_DISPLAY}')\n",
"ax = (corr_per_feat_test\n",
" .loc[~too_few_obs, TOP_N_ORDER]\n",
" .plot\n",
Expand Down Expand Up @@ -1255,7 +1259,7 @@
"fig, ax = plt.subplots(figsize=(4, 2))\n",
"ax = _to_plot.loc[[feature_names.name]].plot.bar(\n",
" rot=0,\n",
" ylabel=f\"{METRIC} for {feature_names.name} ({n_in_comparison:,} intensities)\",\n",
" ylabel=f\"{METRIC} for {FEAT_NAME_DISPLAY} ({n_in_comparison:,} intensities)\",\n",
" # title=f'performance on test data (based on {n_in_comparison:,} measurements)',\n",
" color=COLORS_TO_USE,\n",
" ax=ax,\n",
Expand Down Expand Up @@ -1313,6 +1317,7 @@
" ],\n",
" feat_medians=data.train_X.median(),\n",
" ax=ax,\n",
" feat_name=FEAT_NAME_DISPLAY,\n",
" metric_name=METRIC,\n",
" palette=COLORS_TO_USE\n",
")\n",
Expand All @@ -1326,6 +1331,23 @@
"errors_binned"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b13ecd37",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"(errors_binned\n",
" .set_index(\n",
" ['model', errors_binned.columns[-1]]\n",
" )\n",
" .loc[ORDER_MODELS[0]]\n",
" .sort_values(by=METRIC))"
]
},
{
"cell_type": "markdown",
"id": "26370a1a",
Expand Down Expand Up @@ -1372,7 +1394,7 @@
" fig, ax = plt.subplots(figsize=(4, 2))\n",
" ax = _to_plot.loc[[feature_names.name]].plot.bar(\n",
" rot=0,\n",
" ylabel=f\"{METRIC} for {feature_names.name} ({n_in_comparison:,} intensities)\",\n",
" ylabel=f\"{METRIC} for {FEAT_NAME_DISPLAY} ({n_in_comparison:,} intensities)\",\n",
" # title=f'performance on test data (based on {n_in_comparison:,} measurements)',\n",
" color=vaep.plotting.defaults.assign_colors(\n",
" list(k.upper() for k in SEL_MODELS)),\n",
Expand Down Expand Up @@ -1413,6 +1435,7 @@
" feat_medians=data.train_X.median(),\n",
" ax=ax,\n",
" metric_name=METRIC,\n",
" feat_name=FEAT_NAME_DISPLAY,\n",
" palette=vaep.plotting.defaults.assign_colors(\n",
" list(k.upper() for k in SEL_MODELS))\n",
" )\n",
Expand All @@ -1429,23 +1452,6 @@
" display(errors_binned)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b13ecd37",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"(errors_binned\n",
" .set_index(\n",
" ['model', errors_binned.columns[-1]]\n",
" )\n",
" .loc[ORDER_MODELS[0]]\n",
" .sort_values(by=METRIC))"
]
},
{
"cell_type": "markdown",
"id": "549236ca-9e89-47aa-905c-c97a45d4dc2b",
Expand Down
34 changes: 20 additions & 14 deletions project/01_2_performance_plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

import vaep
import vaep.imputation
Expand Down Expand Up @@ -98,7 +97,7 @@ def build_text(s):
sel_models: str = '' # user defined comparison (comma separated)
# Restrict plotting to top N methods for imputation based on error of validation data, maximum 10
plot_to_n: int = 5

feat_name_display: str = None # display name for feature name (e.g. 'protein group')

# %% [markdown]
# Some argument transformations
Expand All @@ -121,6 +120,7 @@ def build_text(s):
MIN_FREQ = None
MODELS_PASSED = args.models.split(',')
MODELS = MODELS_PASSED.copy()
FEAT_NAME_DISPLAY = args.feat_name_display
SEL_MODELS = None
if args.sel_models:
SEL_MODELS = args.sel_models.split(',')
Expand Down Expand Up @@ -227,6 +227,9 @@ def build_text(s):
split='val',
model_keys=MODELS_PASSED,
shared_columns=[TARGET_COL])
SAMPLE_ID, FEAT_NAME = pred_val.index.names
if not FEAT_NAME_DISPLAY:
FEAT_NAME_DISPLAY = FEAT_NAME
pred_val[MODELS]

# %% [markdown]
Expand Down Expand Up @@ -370,6 +373,7 @@ def build_text(s):
],
feat_medians=data.train_X.median(),
ax=ax,
feat_name=FEAT_NAME_DISPLAY,
palette=TOP_N_COLOR_PALETTE,
metric_name=METRIC,)
ax.set_ylabel(f"Average error ({METRIC})")
Expand All @@ -393,7 +397,7 @@ def build_text(s):
model_keys=MODELS_PASSED,
shared_columns=[TARGET_COL])
pred_test = pred_test.join(freq_feat, on=freq_feat.index.name)
SAMPLE_ID, FEAT_NAME = pred_test.index.names

pred_test

# %% [markdown]
Expand Down Expand Up @@ -553,7 +557,7 @@ def build_text(s):
# %%
kwargs = dict(rot=90,
flierprops=dict(markersize=1),
ylabel=f'correlation per {FEAT_NAME}')
ylabel=f'correlation per {FEAT_NAME_DISPLAY}')
ax = (corr_per_feat_test
.loc[~too_few_obs, TOP_N_ORDER]
.plot
Expand Down Expand Up @@ -635,7 +639,7 @@ def highlight_min(s, color, tolerence=0.00001):
fig, ax = plt.subplots(figsize=(4, 2))
ax = _to_plot.loc[[feature_names.name]].plot.bar(
rot=0,
ylabel=f"{METRIC} for {feature_names.name} ({n_in_comparison:,} intensities)",
ylabel=f"{METRIC} for {FEAT_NAME_DISPLAY} ({n_in_comparison:,} intensities)",
# title=f'performance on test data (based on {n_in_comparison:,} measurements)',
color=COLORS_TO_USE,
ax=ax,
Expand Down Expand Up @@ -670,6 +674,7 @@ def highlight_min(s, color, tolerence=0.00001):
],
feat_medians=data.train_X.median(),
ax=ax,
feat_name=FEAT_NAME_DISPLAY,
metric_name=METRIC,
palette=COLORS_TO_USE
)
Expand All @@ -682,6 +687,14 @@ def highlight_min(s, color, tolerence=0.00001):
errors_binned.to_csv(fname.with_suffix('.csv'))
errors_binned

# %%
(errors_binned
.set_index(
['model', errors_binned.columns[-1]]
)
.loc[ORDER_MODELS[0]]
.sort_values(by=METRIC))

# %% [markdown]
# ### Custom model selection

Expand Down Expand Up @@ -715,7 +728,7 @@ def highlight_min(s, color, tolerence=0.00001):
fig, ax = plt.subplots(figsize=(4, 2))
ax = _to_plot.loc[[feature_names.name]].plot.bar(
rot=0,
ylabel=f"{METRIC} for {feature_names.name} ({n_in_comparison:,} intensities)",
ylabel=f"{METRIC} for {FEAT_NAME_DISPLAY} ({n_in_comparison:,} intensities)",
# title=f'performance on test data (based on {n_in_comparison:,} measurements)',
color=vaep.plotting.defaults.assign_colors(
list(k.upper() for k in SEL_MODELS)),
Expand Down Expand Up @@ -750,6 +763,7 @@ def highlight_min(s, color, tolerence=0.00001):
feat_medians=data.train_X.median(),
ax=ax,
metric_name=METRIC,
feat_name=FEAT_NAME_DISPLAY,
palette=vaep.plotting.defaults.assign_colors(
list(k.upper() for k in SEL_MODELS))
)
Expand All @@ -765,14 +779,6 @@ def highlight_min(s, color, tolerence=0.00001):
# ax.xaxis.set_tick_params(rotation=0) # horizontal
display(errors_binned)

# %%
(errors_binned
.set_index(
['model', errors_binned.columns[-1]]
)
.loc[ORDER_MODELS[0]]
.sort_values(by=METRIC))

# %% [markdown]
# ### Error by non-decimal number of intensity
#
Expand Down
8 changes: 5 additions & 3 deletions vaep/plotting/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ def plot_errors_by_median(pred: pd.DataFrame,
target_col='observed',
ax: Axes = None,
palette: dict = None,
feat_name: str = None,
metric_name: Optional[str] = None,
errwidth: float = 1.2) -> tuple[Axes, pd.DataFrame]:
# calculate absolute errors
Expand All @@ -74,9 +75,10 @@ def plot_errors_by_median(pred: pd.DataFrame,

errors = errors.join(n_obs, on="bin")

feat_name = feat_medians.index.name
if not feat_name:
feat_name = 'feature'
if feat_name is None:
feat_name = feat_medians.index.name
if not feat_name:
feat_name = 'feature'

x_axis_name = f'intensity binned by median of {feat_name}'
len_max_bin = len(str(int(errors['bin'].max())))
Expand Down

0 comments on commit e49c1eb

Please sign in to comment.