Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

✨ update error extraction (seaborn 0.12.2 and 0.13.2) #80

Merged
merged 4 commits into from
Sep 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion .github/workflows/test_pkg_on_colab.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,12 @@ jobs:
image: europe-docker.pkg.dev/colab-images/public/runtime:latest
steps:
- uses: actions/checkout@v4
- name: Install pimms-learn and papermill
- name: Install pimms-learn (from branch) and papermill
if: github.event_name == 'pull_request'
run: |
python3 -m pip install pimms-learn papermill
- name: Install pimms-learn (from PyPI) and papermill
if: github.event_name == 'schedule'
run: |
python3 -m pip install pimms-learn papermill
- name: Run tutorial
Expand Down
4 changes: 2 additions & 2 deletions environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,10 @@ dependencies:
- pandas>=1
- scipy>=1.6
# plotting
- matplotlib>=3.4,<3.9
- matplotlib>=3.4
- python-kaleido
- plotly
- seaborn<0.13
- seaborn
- pip
# ML
- pytorch #=1.13.1=py3.8_cuda11.7_cudnn8_0
Expand Down
76 changes: 29 additions & 47 deletions pimmslearn/plotting/errors.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,13 @@
"""Plot errors based on DataFrame with model predictions."""
from __future__ import annotations

import itertools
from typing import Optional

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from matplotlib.axes import Axes
from seaborn.categorical import _BarPlotter
from seaborn.categorical import EstimateAggregator


import pimmslearn.pandas.calc_errors

Expand Down Expand Up @@ -109,52 +107,36 @@ def plot_errors_by_median(pred: pd.DataFrame,
return ax, errors


def get_data_for_errors_by_median(errors: pd.DataFrame, feat_name, metric_name, seed=None):
"""Extract Bars with confidence intervals from seaborn plot.
Confident intervals are calculated with bootstrapping (sampling the mean).

Relies on internal seaborn class. only used for reporting of source data in the paper.
def get_data_for_errors_by_median(errors: pd.DataFrame,
feat_name: str,
metric_name: str,
model_column: str = 'model',
seed: int = 42) -> pd.DataFrame:
"""Extract Bars with confidence intervals from seaborn plot for seaborn 0.13 and above.
Confident intervals are calculated with bootstrapping(sampling the mean).

Parameters
----------
errors: pd.DataFrame
DataFrame created by `plot_errors_by_median` function
feat_name: str
feature name assigned(was transformed to 'intensity binned by median of {feat_name}')
metric_name: str
Metric used to calculate errors(MAE, MSE, etc) of intensities in bin
model_column: str
model_column in errors, defining model names
"""
x_axis_name = f'intensity binned by median of {feat_name}'
aggregator = EstimateAggregator("mean", ("ci", 95), n_boot=1_000, seed=seed)
# ! need to iterate over all models myself using groupby
ret = (errors
.groupby(by=[x_axis_name, model_column,], observed=True)
[[x_axis_name, model_column, metric_name]]
.apply(lambda df: aggregator(df, metric_name))
.reset_index())
ret.columns = ["bin", model_column, "mean", "ci_low", "ci_high"]
return ret

plotter = _BarPlotter(data=errors, x=x_axis_name, y=metric_name, hue='model',
order=None, hue_order=None,
estimator="mean", errorbar=("ci", 95), n_boot=1000, units=None, seed=seed,
orient=None, color=None, palette=None, saturation=.75, width=.8,
errcolor=".26", errwidth=None, capsize=None, dodge=True)
ax = plt.gca()
plotter.plot(ax, {})
plt.close(ax.get_figure())
mean, cf_interval = plotter.statistic.flatten(), plotter.confint.reshape(-1, 2)
plotted = pd.DataFrame(np.concatenate((mean.reshape(-1, 1), cf_interval), axis=1), columns=[
'mean', 'ci_low', 'ci_high'])
_index = pd.DataFrame(list(itertools.product(
(_l.get_text() for _l in ax.get_xticklabels()), # bins x-axis
(_l.get_text() for _l in ax.get_legend().get_texts()), # models legend
)
), columns=['bin', 'model'])
plotted = pd.concat([_index, plotted], axis=1)
return plotted


# def get_data_for_errors_by_median_v2(errors: pd.DataFrame, feat_name, metric_name):
# from seaborn._statistics import (
# EstimateAggregator,
# WeightedAggregator,
# )
# from seaborn.categorical import _CategoricalAggPlotter, WeightedAggregator, EstimateAggregator
# p = _CategoricalAggPlotter(
# data=data,
# variables=dict(x=x, y=y, hue=hue, units=units, weight=weights),
# order=order,
# orient=orient,
# color=color,
# legend=legend,
# )

# agg_cls = WeightedAggregator if "weight" in p.plot_data else EstimateAggregator
# aggregator = agg_cls(estimator, errorbar, n_boot=n_boot, seed=seed)
# err_kws = {} if err_kws is None else normalize_kwargs(err_kws, mpl.lines.Line2D)


def plot_rolling_error(errors: pd.DataFrame, metric_name: str, window: int = 200,
Expand Down
4 changes: 2 additions & 2 deletions project/04_1_train_pimms_models.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
" print(f\"Running in colab and pimms-learn ({_v}) is installed.\")\n",
" except metadata.PackageNotFoundError:\n",
" print(\"Install PIMMS...\")\n",
" # !pip install git+https://github.com/RasmussenLab/pimms.git@dev\n",
" # !pip install git+https://github.com/RasmussenLab/pimms.git\n",
" !pip install pimms-learn"
]
},
Expand Down Expand Up @@ -364,7 +364,7 @@
"metadata": {},
"outputs": [],
"source": [
"CollaborativeFilteringTransformer?"
"# # CollaborativeFilteringTransformer?"
]
},
{
Expand Down
4 changes: 2 additions & 2 deletions project/04_1_train_pimms_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
print(f"Running in colab and pimms-learn ({_v}) is installed.")
except metadata.PackageNotFoundError:
print("Install PIMMS...")
# # !pip install git+https://github.com/RasmussenLab/pimms.git@dev
# # !pip install git+https://github.com/RasmussenLab/pimms.git
# !pip install pimms-learn

# %% [markdown]
Expand Down Expand Up @@ -167,7 +167,7 @@
# Inspect annotations of the scikit-learn like Transformer:

# %%
# # CollaborativeFilteringTransformer?
# # # CollaborativeFilteringTransformer?

# %% [markdown]
# Let's set up collaborative filtering without a validation or test set, using
Expand Down
4 changes: 2 additions & 2 deletions project/workflow/envs/pimms.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,10 @@ dependencies:
- pandas>=1
- scipy>=1.6
# plotting
- matplotlib<3.9
- matplotlib
- python-kaleido
- plotly
- seaborn<0.13
- seaborn
- pip
# ML
- pytorch #=1.13.1=py3.8_cuda11.7_cudnn8_0
Expand Down
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,13 @@ classifiers = [
dependencies = [
"njab>=0.0.8",
"numpy",
"matplotlib<3.9",
"matplotlib",
"pandas",
"plotly",
"torch",
"scikit-learn>=1.0",
"scipy",
"seaborn<0.13",
"seaborn",
"fastai",
"omegaconf",
"tqdm",
Expand Down
3 changes: 2 additions & 1 deletion tests/plotting/test_errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,8 @@ def expected_plotted():
plotted_path = file_dir / 'expected_plotted.csv'
# ! Windows reads in new line in string characters as '\r\n'
df = pd.read_csv(plotted_path, sep=',', index_col=0)
df["bin"] = df["bin"].str.replace('\r\n', '\n')
df["bin"] = df["bin"].str.replace('\r\n', '\n').astype('category')
df = df.sort_values(by=['bin', 'model']).reset_index(drop=True)
return df


Expand Down
Loading