Skip to content

Commit

Permalink
🎨 formatting and plot sizes adjusted
Browse files Browse the repository at this point in the history
  • Loading branch information
Henry committed Feb 2, 2024
1 parent 0b761d2 commit d0aec7e
Show file tree
Hide file tree
Showing 8 changed files with 121 additions and 77 deletions.
3 changes: 2 additions & 1 deletion project/00_5_training_data_exploration.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
# extension: .py
# format_name: percent
# format_version: '1.3'
# jupytext_version: 1.14.5
# jupytext_version: 1.15.0
# kernelspec:
# display_name: vaep
# language: python
Expand Down Expand Up @@ -179,6 +179,7 @@ def get_dynamic_range(min_max):
f"{_levels_dropped}")
# allows overwriting of index name, also to None
data.columns.name = COL_INDEX_NAME
data


# %% [markdown]
Expand Down
22 changes: 12 additions & 10 deletions project/01_0_split_data.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -50,9 +50,9 @@
"\n",
"\n",
"pd.options.display.max_columns = 32\n",
"plt.rcParams['figure.figsize'] = [4, 2]\n",
"plt.rcParams['figure.figsize'] = [3, 2]\n",
"\n",
"vaep.plotting.make_large_descriptors(6)\n",
"vaep.plotting.make_large_descriptors(7)\n",
"\n",
"figures = {} # collection of ax or figures\n",
"dumps = {} # collection of data dumps"
Expand Down Expand Up @@ -820,12 +820,14 @@
"metadata": {},
"outputs": [],
"source": [
"ax = df_w_date.boxplot(rot=80,\n",
" figsize=(8, 3),\n",
" fontsize=6,\n",
" showfliers=False,\n",
" showcaps=False\n",
" )\n",
"ax = df_w_date.plot.box(rot=80,\n",
" figsize=(7, 3),\n",
" fontsize=7,\n",
" showfliers=False,\n",
" showcaps=False,\n",
" boxprops=dict(linewidth=.4, color='darkblue'),\n",
" flierprops=dict(markersize=.4, color='lightblue'),\n",
" )\n",
"_ = vaep.plotting.select_xticks(ax)\n",
"fig = ax.get_figure()\n",
"fname = params.out_figures / f'0_{group}_median_boxplot'\n",
Expand Down Expand Up @@ -1004,7 +1006,7 @@
"group = 2\n",
"! move parameter checks to start of script\n",
"if 0.0 <= params.frac_mnar <= 1.0:\n",
" fig, axes = plt.subplots(1, 2, figsize=(8, 2))\n",
" fig, axes = plt.subplots(1, 2, figsize=(6, 2))\n",
" quantile_frac = df_long.quantile(params.frac_non_train)\n",
" rng = np.random.default_rng(params.random_state)\n",
" threshold = pd.Series(rng.normal(loc=float(quantile_frac),\n",
Expand Down Expand Up @@ -1404,7 +1406,7 @@
"medians = medians.join(feat_with_median, on='median_floor')\n",
"medians = medians.apply(lambda s: \"{:02,d} (N={:3,d})\".format(*s), axis=1)\n",
"\n",
"fig, ax = plt.subplots(figsize=(8, 2))\n",
"fig, ax = plt.subplots(figsize=(6, 2))\n",
"s = 1\n",
"s_axes = pd.DataFrame({'medians': medians,\n",
" 'validation split': splits.val_y.notna().sum(),\n",
Expand Down
22 changes: 12 additions & 10 deletions project/01_0_split_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,9 +54,9 @@ def add_meta_data(df: pd.DataFrame, df_meta: pd.DataFrame):


pd.options.display.max_columns = 32
plt.rcParams['figure.figsize'] = [4, 2]
plt.rcParams['figure.figsize'] = [3, 2]

vaep.plotting.make_large_descriptors(6)
vaep.plotting.make_large_descriptors(7)

figures = {} # collection of ax or figures
dumps = {} # collection of data dumps
Expand Down Expand Up @@ -507,12 +507,14 @@ def join_as_str(seq):
df_w_date

# %%
ax = df_w_date.boxplot(rot=80,
figsize=(8, 3),
fontsize=6,
showfliers=False,
showcaps=False
)
ax = df_w_date.plot.box(rot=80,
figsize=(7, 3),
fontsize=7,
showfliers=False,
showcaps=False,
boxprops=dict(linewidth=.4, color='darkblue'),
flierprops=dict(markersize=.4, color='lightblue'),
)
_ = vaep.plotting.select_xticks(ax)
fig = ax.get_figure()
fname = params.out_figures / f'0_{group}_median_boxplot'
Expand Down Expand Up @@ -612,7 +614,7 @@ def join_as_str(seq):
group = 2
# TODO: move parameter checks to start of script
if 0.0 <= params.frac_mnar <= 1.0:
fig, axes = plt.subplots(1, 2, figsize=(8, 2))
fig, axes = plt.subplots(1, 2, figsize=(6, 2))
quantile_frac = df_long.quantile(params.frac_non_train)
rng = np.random.default_rng(params.random_state)
threshold = pd.Series(rng.normal(loc=float(quantile_frac),
Expand Down Expand Up @@ -890,7 +892,7 @@ def join_as_str(seq):
medians = medians.join(feat_with_median, on='median_floor')
medians = medians.apply(lambda s: "{:02,d} (N={:3,d})".format(*s), axis=1)

fig, ax = plt.subplots(figsize=(8, 2))
fig, ax = plt.subplots(figsize=(6, 2))
s = 1
s_axes = pd.DataFrame({'medians': medians,
'validation split': splits.val_y.notna().sum(),
Expand Down
13 changes: 7 additions & 6 deletions project/01_2_performance_plots.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,8 @@
"pd.options.display.min_rows = 10\n",
"pd.options.display.max_colwidth = 100\n",
"\n",
"plt.rcParams.update({'figure.figsize': (4, 2)})\n",
"vaep.plotting.make_large_descriptors(6)\n",
"plt.rcParams.update({'figure.figsize': (3, 2)})\n",
"vaep.plotting.make_large_descriptors(7)\n",
"\n",
"logger = vaep.logging.setup_nb_logger()\n",
"logging.getLogger('fontTools').setLevel(logging.WARNING)\n",
Expand Down Expand Up @@ -1256,16 +1256,17 @@
},
"outputs": [],
"source": [
"fig, ax = plt.subplots(figsize=(4, 2))\n",
"\n",
"fig, ax = plt.subplots(figsize=(6, 2))\n",
"ax = _to_plot.loc[[feature_names.name]].plot.bar(\n",
" rot=0,\n",
" ylabel=f\"{METRIC} for {FEAT_NAME_DISPLAY} ({n_in_comparison:,} intensities)\",\n",
" # title=f'performance on test data (based on {n_in_comparison:,} measurements)',\n",
" color=COLORS_TO_USE,\n",
" ax=ax,\n",
" width=.8)\n",
"ax = vaep.plotting.add_height_to_barplot(ax, size=5)\n",
"ax = vaep.plotting.add_text_to_barplot(ax, _to_plot.loc[\"text\"], size=5)\n",
"ax = vaep.plotting.add_height_to_barplot(ax, size=7)\n",
"ax = vaep.plotting.add_text_to_barplot(ax, _to_plot.loc[\"text\"], size=7)\n",
"ax.set_xticklabels([])\n",
"fname = args.out_figures / f'2_{group}_performance_test.pdf'\n",
"figures[fname.stem] = fname\n",
Expand Down Expand Up @@ -1309,7 +1310,7 @@
"outputs": [],
"source": [
"vaep.plotting.make_large_descriptors(7)\n",
"fig, ax = plt.subplots(figsize=(8, 2))\n",
"fig, ax = plt.subplots(figsize=(6, 2))\n",
"\n",
"ax, errors_binned = vaep.plotting.errors.plot_errors_by_median(\n",
" pred=pred_test[\n",
Expand Down
13 changes: 7 additions & 6 deletions project/01_2_performance_plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,8 @@
pd.options.display.min_rows = 10
pd.options.display.max_colwidth = 100

plt.rcParams.update({'figure.figsize': (4, 2)})
vaep.plotting.make_large_descriptors(6)
plt.rcParams.update({'figure.figsize': (3, 2)})
vaep.plotting.make_large_descriptors(7)

logger = vaep.logging.setup_nb_logger()
logging.getLogger('fontTools').setLevel(logging.WARNING)
Expand Down Expand Up @@ -636,16 +636,17 @@ def highlight_min(s, color, tolerence=0.00001):


# %%
fig, ax = plt.subplots(figsize=(4, 2))

fig, ax = plt.subplots(figsize=(6, 2))
ax = _to_plot.loc[[feature_names.name]].plot.bar(
rot=0,
ylabel=f"{METRIC} for {FEAT_NAME_DISPLAY} ({n_in_comparison:,} intensities)",
# title=f'performance on test data (based on {n_in_comparison:,} measurements)',
color=COLORS_TO_USE,
ax=ax,
width=.8)
ax = vaep.plotting.add_height_to_barplot(ax, size=5)
ax = vaep.plotting.add_text_to_barplot(ax, _to_plot.loc["text"], size=5)
ax = vaep.plotting.add_height_to_barplot(ax, size=7)
ax = vaep.plotting.add_text_to_barplot(ax, _to_plot.loc["text"], size=7)
ax.set_xticklabels([])
fname = args.out_figures / f'2_{group}_performance_test.pdf'
figures[fname.stem] = fname
Expand All @@ -666,7 +667,7 @@ def highlight_min(s, color, tolerence=0.00001):

# %%
vaep.plotting.make_large_descriptors(7)
fig, ax = plt.subplots(figsize=(8, 2))
fig, ax = plt.subplots(figsize=(6, 2))

ax, errors_binned = vaep.plotting.errors.plot_errors_by_median(
pred=pred_test[
Expand Down
64 changes: 40 additions & 24 deletions project/02_3_grid_search_analysis.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
"metadata": {},
"outputs": [],
"source": [
"import logging\n",
"import pathlib\n",
"import pandas as pd\n",
"import plotly.express as px\n",
Expand All @@ -38,7 +39,8 @@
"pd.options.display.max_rows = 100\n",
"pd.options.display.multi_sparse = False\n",
"\n",
"logger = vaep.logging.setup_nb_logger()"
"logger = vaep.logging.setup_nb_logger()\n",
"logging.getLogger('fontTools').setLevel(logging.WARNING)"
]
},
{
Expand Down Expand Up @@ -603,15 +605,13 @@
"cell_type": "code",
"execution_count": null,
"id": "f8190d51-c4db-4aae-8b91-11641958a0f8",
"metadata": {
"lines_to_next_cell": 1
},
"metadata": {},
"outputs": [],
"source": [
"plt.rcParams['figure.figsize'] = (8, 4)\n",
"plt.rcParams['figure.figsize'] = (7, 4)\n",
"plt.rcParams['lines.linewidth'] = 2\n",
"plt.rcParams['lines.markersize'] = 3\n",
"vaep.plotting.make_large_descriptors(5)\n",
"vaep.plotting.make_large_descriptors(7)\n",
"\n",
"col_order = ('valid_fake_na', 'test_fake_na')\n",
"row_order = ('MAE', 'MSE')\n",
Expand All @@ -628,24 +628,29 @@
" palette=vaep.plotting.defaults.color_model_mapping,\n",
" height=2,\n",
" aspect=1.8,\n",
" kind=\"scatter\"\n",
" kind=\"scatter\",\n",
")\n",
"fg.fig.get_size_inches()\n",
"\n",
"(ax_00, ax_01), (ax_10, ax_11) = fg.axes\n",
"ax_00.set_ylabel(row_order[0])\n",
"ax_10.set_ylabel(row_order[1])\n",
"_ = ax_00.set_title('validation data') # col_order[0]\n",
"_ = ax_01.set_title('test data') # col_order[1]\n",
"ax_10.set_xticklabels(ax_10.get_xticklabels(),\n",
" rotation=45,\n",
" horizontalalignment='right')\n",
"ax_10.set_xlabel('number of parameters') # n_params\n",
"ax_11.set_xticklabels(ax_11.get_xticklabels(),\n",
" rotation=45,\n",
" horizontalalignment='right')\n",
"ax_11.set_xlabel('number of parameters')\n",
"ax_10.xaxis.set_major_formatter(\"{x:,.0f}\")\n",
"ax_11.xaxis.set_major_formatter(\"{x:,.0f}\")\n",
"_ = ax_10.set_title('')\n",
"_ = ax_11.set_title('')\n",
"fg.tight_layout()\n",
"fname\n",
"fname = FOLDER / f\"hyperpar_results_by_parameters_val+test.pdf\"\n",
"fname = FOLDER / \"hyperpar_results_by_parameters_val+test.pdf\"\n",
"files_out[fname.name] = fname.as_posix()\n",
"fg.savefig(fname)\n",
"fg.savefig(fname.with_suffix('.png'), dpi=300)"
Expand All @@ -658,6 +663,8 @@
"metadata": {},
"outputs": [],
"source": [
"\n",
"\n",
"def plot_by_params(data_split: str = '', subset: str = ''):\n",
" selected = metrics_long\n",
" if data_split:\n",
Expand Down Expand Up @@ -1107,13 +1114,16 @@
"outputs": [],
"source": [
"mask = errors_smoothed[freq_feat.name] >= FREQ_MIN\n",
"ax = errors_smoothed.loc[mask].rename_axis('', axis=1).plot(x=freq_feat.name,\n",
" xlabel='freq/feature prevalence (across samples)',\n",
" ylabel=f'rolling average error ({METRIC})',\n",
" xlim=(\n",
" FREQ_MIN, errors_smoothed[freq_feat.name].max()),\n",
" # title=f'Rolling average error by feature frequency {msg_annotation}'\n",
" )\n",
"ax = (errors_smoothed\n",
" .loc[mask]\n",
" .rename_axis('', axis=1)\n",
" .plot(x=freq_feat.name,\n",
" xlabel='freq/feature prevalence (across samples)',\n",
" ylabel=f'rolling average error ({METRIC})',\n",
" xlim=(\n",
" FREQ_MIN, errors_smoothed[freq_feat.name].max()),\n",
" # title=f'Rolling average error by feature frequency {msg_annotation}'\n",
" ))\n",
"\n",
"msg_annotation = f\"(Latend dim: {min_latent}, No. of feat: {M_feat}, window_size: {window_size})\"\n",
"print(msg_annotation)\n",
Expand Down Expand Up @@ -1155,7 +1165,10 @@
},
"outputs": [],
"source": [
"fig = px_vaep.line(errors_smoothed_long.loc[errors_smoothed_long[freq_feat.name] >= FREQ_MIN].join(n_obs_error_is_based_on).sort_values(by='freq'),\n",
"fig = px_vaep.line((errors_smoothed_long\n",
" .loc[errors_smoothed_long[freq_feat.name] >= FREQ_MIN]\n",
" .join(n_obs_error_is_based_on)\n",
" .sort_values(by='freq')),\n",
" x=freq_feat.name,\n",
" color='model',\n",
" y='rolling error average',\n",
Expand Down Expand Up @@ -1424,13 +1437,16 @@
"errors_smoothed[order_models] = errors[order_models].rolling(\n",
" window=window_size, min_periods=1).mean()\n",
"mask = errors_smoothed[freq_feat.name] >= FREQ_MIN\n",
"ax = errors_smoothed.loc[mask].rename_axis('', axis=1).plot(x=freq_feat.name,\n",
" ylabel='rolling error average',\n",
" xlabel='freq/feature prevalence (across samples)',\n",
" xlim=(\n",
" FREQ_MIN, freq_feat.max()),\n",
" # title=f'Rolling average error by feature frequency {msg_annotation}'\n",
" )\n",
"ax = (errors_smoothed\n",
" .loc[mask]\n",
" .rename_axis('', axis=1)\n",
" .plot(x=freq_feat.name,\n",
" ylabel='rolling error average',\n",
" xlabel='freq/feature prevalence (across samples)',\n",
" xlim=(\n",
" FREQ_MIN, freq_feat.max()),\n",
" # title=f'Rolling average error by feature frequency {msg_annotation}'\n",
" ))\n",
"\n",
"vaep.savefig(\n",
" ax.get_figure(),\n",
Expand Down
Loading

0 comments on commit d0aec7e

Please sign in to comment.