Skip to content

Commit

Permalink
✨ update error extraction (seaborn 0.12.2 and 0.13.2)
Browse files Browse the repository at this point in the history
- check that everything works with 0.12.2 before testing switching to 0.13.2
  • Loading branch information
enryH committed Aug 30, 2024
1 parent cfbb8e6 commit c673fb9
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 48 deletions.
76 changes: 29 additions & 47 deletions pimmslearn/plotting/errors.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,13 @@
"""Plot errors based on DataFrame with model predictions."""
from __future__ import annotations

import itertools
from typing import Optional

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from matplotlib.axes import Axes
from seaborn.categorical import _BarPlotter
from seaborn.categorical import EstimateAggregator


import pimmslearn.pandas.calc_errors

Expand Down Expand Up @@ -109,52 +107,36 @@ def plot_errors_by_median(pred: pd.DataFrame,
return ax, errors


def get_data_for_errors_by_median(errors: pd.DataFrame, feat_name, metric_name, seed=None):
"""Extract Bars with confidence intervals from seaborn plot.
Confidence intervals are calculated with bootstrapping (sampling the mean).
Relies on an internal seaborn class. Only used for reporting of source data in the paper.
def get_data_for_errors_by_median(errors: pd.DataFrame,
                                  feat_name: str,
                                  metric_name: str,
                                  model_column: str = 'model',
                                  seed: int = 42) -> pd.DataFrame:
    """Compute per-bin, per-model mean errors with confidence intervals.

    The values are computed directly with seaborn's ``EstimateAggregator``
    (mean estimator, 95% confidence interval, 1000 bootstrap samples) rather
    than being read back from a drawn bar plot. Written for seaborn 0.13 and
    above.

    Parameters
    ----------
    errors : pd.DataFrame
        DataFrame created by the `plot_errors_by_median` function. It has to
        contain the column 'intensity binned by median of {feat_name}', the
        `model_column` and the `metric_name` column.
    feat_name : str
        Feature name that was used to build the bin column name
        'intensity binned by median of {feat_name}'.
    metric_name : str
        Column holding the error metric (MAE, MSE, etc.) of the intensities.
    model_column : str
        Column in `errors` that holds the model names.
    seed : int
        Seed handed to the aggregator for the bootstrap resampling.

    Returns
    -------
    pd.DataFrame
        One row per observed (bin, model) combination with the columns
        "bin", `model_column`, "mean", "ci_low" and "ci_high".
    """
    bin_column = f'intensity binned by median of {feat_name}'
    group_keys = [bin_column, model_column]
    estimate = EstimateAggregator("mean", ("ci", 95), n_boot=1_000, seed=seed)

    def _summarize(group: pd.DataFrame):
        # One aggregator call per (bin, model) group; seaborn is not asked to
        # iterate over the models itself.
        return estimate(group, metric_name)

    grouped = errors.groupby(by=group_keys, observed=True)
    selected = grouped[group_keys + [metric_name]]
    summary = selected.apply(_summarize).reset_index()
    # Columns are renamed by position: two group keys followed by the three
    # aggregator outputs (assumed order: estimate, lower, upper bound).
    summary.columns = ["bin", model_column, "mean", "ci_low", "ci_high"]
    return summary

plotter = _BarPlotter(data=errors, x=x_axis_name, y=metric_name, hue='model',
order=None, hue_order=None,
estimator="mean", errorbar=("ci", 95), n_boot=1000, units=None, seed=seed,
orient=None, color=None, palette=None, saturation=.75, width=.8,
errcolor=".26", errwidth=None, capsize=None, dodge=True)
ax = plt.gca()
plotter.plot(ax, {})
plt.close(ax.get_figure())
mean, cf_interval = plotter.statistic.flatten(), plotter.confint.reshape(-1, 2)
plotted = pd.DataFrame(np.concatenate((mean.reshape(-1, 1), cf_interval), axis=1), columns=[
'mean', 'ci_low', 'ci_high'])
_index = pd.DataFrame(list(itertools.product(
(_l.get_text() for _l in ax.get_xticklabels()), # bins x-axis
(_l.get_text() for _l in ax.get_legend().get_texts()), # models legend
)
), columns=['bin', 'model'])
plotted = pd.concat([_index, plotted], axis=1)
return plotted


# def get_data_for_errors_by_median_v2(errors: pd.DataFrame, feat_name, metric_name):
# from seaborn._statistics import (
# EstimateAggregator,
# WeightedAggregator,
# )
# from seaborn.categorical import _CategoricalAggPlotter, WeightedAggregator, EstimateAggregator
# p = _CategoricalAggPlotter(
# data=data,
# variables=dict(x=x, y=y, hue=hue, units=units, weight=weights),
# order=order,
# orient=orient,
# color=color,
# legend=legend,
# )

# agg_cls = WeightedAggregator if "weight" in p.plot_data else EstimateAggregator
# aggregator = agg_cls(estimator, errorbar, n_boot=n_boot, seed=seed)
# err_kws = {} if err_kws is None else normalize_kwargs(err_kws, mpl.lines.Line2D)


def plot_rolling_error(errors: pd.DataFrame, metric_name: str, window: int = 200,
Expand Down
3 changes: 2 additions & 1 deletion tests/plotting/test_errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,8 @@ def expected_plotted():
plotted_path = file_dir / 'expected_plotted.csv'
# ! On Windows, newline characters inside strings are read in as '\r\n'
df = pd.read_csv(plotted_path, sep=',', index_col=0)
df["bin"] = df["bin"].str.replace('\r\n', '\n')
df["bin"] = df["bin"].str.replace('\r\n', '\n').astype('category')
df = df.sort_values(by=['bin', 'model']).reset_index(drop=True)
return df


Expand Down

0 comments on commit c673fb9

Please sign in to comment.