Skip to content

Commit

Permalink
🎨 annotate some functions, remove tags
Browse files Browse the repository at this point in the history
  • Loading branch information
Henry committed May 31, 2024
1 parent 115e681 commit 7fc0193
Show file tree
Hide file tree
Showing 6 changed files with 56 additions and 25 deletions.
1 change: 0 additions & 1 deletion project/01_0_split_data.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -1249,7 +1249,6 @@
{
"cell_type": "code",
"execution_count": null,
"id": "34ee6256",
"metadata": {
"lines_to_next_cell": 2,
"tags": [
Expand Down
37 changes: 27 additions & 10 deletions project/02_3_grid_search_analysis.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -16,22 +16,25 @@
"metadata": {},
"outputs": [],
"source": [
"import snakemake\n",
"import logging\n",
"import pathlib\n",
"import pandas as pd\n",
"import plotly.express as px\n",
"\n",
"import matplotlib\n",
"import matplotlib.pyplot as plt\n",
"import pandas as pd\n",
"import plotly.express as px\n",
"import seaborn as sns\n",
"\n",
"import vaep.io\n",
"import vaep.nb\n",
"import vaep.pandas\n",
"import vaep.plotting.plotly as px_vaep\n",
"from vaep.analyzers import compare_predictions\n",
"import vaep.utils\n",
"from vaep import sampling\n",
"from vaep.analyzers import compare_predictions\n",
"from vaep.io import datasplits\n",
"import vaep.utils\n",
"import vaep.pandas\n",
"import vaep.io\n",
"import vaep.nb\n",
"\n",
"matplotlib.rcParams['figure.figsize'] = [12.0, 6.0]\n",
"\n",
"\n",
Expand Down Expand Up @@ -96,15 +99,18 @@
"cell_type": "code",
"execution_count": null,
"id": "8f0497b1-5f91-45e9-a3e1-88de08b928a9",
"metadata": {},
"metadata": {
"lines_to_next_cell": 2
},
"outputs": [],
"source": [
"# not robust\n",
"try:\n",
" ORDER = {'model': snakemake.params.models}\n",
" FILE_FORMAT = snakemake.params.file_format\n",
"except AttributeError:\n",
" ORDER = {'model': ['CF', 'DAE', 'VAE']}\n",
"FILE_FORMAT = snakemake.params.file_format"
" FILE_FORMAT = 'csv'"
]
},
{
Expand Down Expand Up @@ -607,6 +613,16 @@
"id": "f8190d51-c4db-4aae-8b91-11641958a0f8",
"metadata": {},
"outputs": [],
"source": [
"view = metrics_long[[\"model\", \"n_params\", \"data_split\", \"metric_name\", \"metric_value\"]]"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f98b49d8",
"metadata": {},
"outputs": [],
"source": [
"plt.rcParams['figure.figsize'] = (7, 4)\n",
"plt.rcParams['lines.linewidth'] = 2\n",
Expand All @@ -616,7 +632,7 @@
"col_order = ('valid_fake_na', 'test_fake_na')\n",
"row_order = ('MAE', 'MSE')\n",
"fg = sns.relplot(\n",
" data=metrics_long,\n",
" data=view,\n",
" x='n_params',\n",
" y='metric_value',\n",
" col=\"data_split\",\n",
Expand Down Expand Up @@ -652,6 +668,7 @@
"fname\n",
"fname = FOLDER / \"hyperpar_results_by_parameters_val+test.pdf\"\n",
"files_out[fname.name] = fname.as_posix()\n",
"view.to_excel(fname.with_suffix('.xlsx'))\n",
"fg.savefig(fname)\n",
"fg.savefig(fname.with_suffix('.png'), dpi=300)"
]
Expand Down
27 changes: 18 additions & 9 deletions project/02_3_grid_search_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,22 +17,25 @@
# # Analysis of grid hyperparameter search

# %%
import snakemake
import logging
import pathlib
import pandas as pd
import plotly.express as px

import matplotlib
import matplotlib.pyplot as plt
import pandas as pd
import plotly.express as px
import seaborn as sns

import vaep.io
import vaep.nb
import vaep.pandas
import vaep.plotting.plotly as px_vaep
from vaep.analyzers import compare_predictions
import vaep.utils
from vaep import sampling
from vaep.analyzers import compare_predictions
from vaep.io import datasplits
import vaep.utils
import vaep.pandas
import vaep.io
import vaep.nb

matplotlib.rcParams['figure.figsize'] = [12.0, 6.0]


Expand Down Expand Up @@ -66,9 +69,11 @@
# not robust: relies on an AttributeError to detect missing snakemake params and fall back to defaults
try:
ORDER = {'model': snakemake.params.models}
FILE_FORMAT = snakemake.params.file_format
except AttributeError:
ORDER = {'model': ['CF', 'DAE', 'VAE']}
FILE_FORMAT = snakemake.params.file_format
FILE_FORMAT = 'csv'


# %%
path_metrics = pathlib.Path(metrics_csv)
Expand Down Expand Up @@ -318,6 +323,9 @@
hover_data['data_split'] = True
hover_data['metric_value'] = ':.4f'

# %%
view = metrics_long[["model", "n_params", "data_split", "metric_name", "metric_value"]]

# %%
plt.rcParams['figure.figsize'] = (7, 4)
plt.rcParams['lines.linewidth'] = 2
Expand All @@ -327,7 +335,7 @@
col_order = ('valid_fake_na', 'test_fake_na')
row_order = ('MAE', 'MSE')
fg = sns.relplot(
data=metrics_long,
data=view,
x='n_params',
y='metric_value',
col="data_split",
Expand Down Expand Up @@ -363,6 +371,7 @@
fname
fname = FOLDER / "hyperpar_results_by_parameters_val+test.pdf"
files_out[fname.name] = fname.as_posix()
view.to_excel(fname.with_suffix('.xlsx'))
fg.savefig(fname)
fg.savefig(fname.with_suffix('.png'), dpi=300)

Expand Down
7 changes: 4 additions & 3 deletions vaep/analyzers/analyzers.py
Original file line number Diff line number Diff line change
Expand Up @@ -395,9 +395,10 @@ def get_consecutive_data_indices(df, n_samples):
return df.loc[index[start_sample:start_sample + n_samples]]


def corr_lower_triangle(df):
"""Compute the correlation matrix, returning only unique values."""
corr_df = df.corr()
def corr_lower_triangle(df, **kwargs):
"""Compute the correlation matrix, returning only unique values.
"""
corr_df = df.corr(**kwargs)
lower_triangle = pd.DataFrame(
np.tril(np.ones(corr_df.shape), -1)).astype(bool)
lower_triangle.index, lower_triangle.columns = corr_df.index, corr_df.columns
Expand Down
7 changes: 6 additions & 1 deletion vaep/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,10 @@

NUMPY_ONE = np.int64(1)

__all__ = ['ae', 'analysis', 'collab', 'vae', 'plot_loss', 'plot_training_losses',
'calc_net_weight_count', 'RecorderDump', 'split_prediction_by_mask',
'compare_indices', 'collect_metrics', 'calculte_metrics',
'Metrics', 'get_df_from_nested_dict']


def plot_loss(recorder: learner.Recorder,
Expand Down Expand Up @@ -312,7 +316,8 @@ def __repr__(self):


def get_df_from_nested_dict(nested_dict,
column_levels=('data_split', 'model', 'metric_name'),
column_levels=(
'data_split', 'model', 'metric_name'),
row_name='subset'):
metrics = {}
for k, run_metrics in nested_dict.items():
Expand Down
2 changes: 1 addition & 1 deletion vaep/pandas/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,7 +285,7 @@ def get_lower_whiskers(df: pd.DataFrame, factor: float = 1.5) -> pd.Series:
return ret


def get_counts_per_bin(df: pd.DataFrame, bins: range, columns: Optional[List[str]] = None):
def get_counts_per_bin(df: pd.DataFrame, bins: range, columns: Optional[List[str]] = None) -> pd.DataFrame:
"""Return counts per bin for selected columns in DataFrame."""
counts_per_bin = dict()
if columns is None:
Expand Down

0 comments on commit 7fc0193

Please sign in to comment.