From c673fb9e724869aaf9486dbcceff18190c8aa5a2 Mon Sep 17 00:00:00 2001
From: Henry Webel
Date: Fri, 30 Aug 2024 08:20:02 +0200
Subject: [PATCH] :sparkles: update error extraction (seaborn 0.12.2 and 0.13.2)

- check that everything works with 0.12.2 before testing switching to 0.13.2
---
NOTE(review): this patch was recovered from a copy whose line breaks and
indentation had been collapsed. The layout below was rebuilt so that every
hunk matches its @@ line counts and the diffstat (31 insertions, 48
deletions); the placement of blank context lines and the indentation of
continuation lines are a best guess. If it does not apply cleanly,
regenerate it from the commit with
`git format-patch -1 c673fb9e724869aaf9486dbcceff18190c8aa5a2`.

NOTE(review): points to confirm when reviewing the change itself:
  * the default of `seed` changes from None to 42 and a `model_column`
    parameter (default 'model') is inserted before it, so positional
    callers passing a seed as fourth argument would now set `model_column`;
  * the new docstring says "seaborn 0.13 and above" while the subject line
    targets 0.12.2 as well;
  * `EstimateAggregator` is imported from `seaborn.categorical`; the removed
    commented-out draft imported it from `seaborn._statistics`, so this is
    presumably a re-export of a private helper - verify for both versions.

 pimmslearn/plotting/errors.py | 76 +++++++++++++----------------------
 tests/plotting/test_errors.py |  3 +-
 2 files changed, 31 insertions(+), 48 deletions(-)

diff --git a/pimmslearn/plotting/errors.py b/pimmslearn/plotting/errors.py
index ea685589..86ddedf3 100644
--- a/pimmslearn/plotting/errors.py
+++ b/pimmslearn/plotting/errors.py
@@ -1,15 +1,13 @@
 """Plot errors based on DataFrame with model predictions."""
 from __future__ import annotations
 
-import itertools
 from typing import Optional
 
-import matplotlib.pyplot as plt
-import numpy as np
 import pandas as pd
 import seaborn as sns
 from matplotlib.axes import Axes
-from seaborn.categorical import _BarPlotter
+from seaborn.categorical import EstimateAggregator
+
 
 import pimmslearn.pandas.calc_errors
 
@@ -109,52 +107,36 @@ def plot_errors_by_median(pred: pd.DataFrame,
     return ax, errors
 
 
-def get_data_for_errors_by_median(errors: pd.DataFrame, feat_name, metric_name, seed=None):
-    """Extract Bars with confidence intervals from seaborn plot.
-    Confident intervals are calculated with bootstrapping (sampling the mean).
-
-    Relies on internal seaborn class. only used for reporting of source data in the paper.
+def get_data_for_errors_by_median(errors: pd.DataFrame,
+                                  feat_name: str,
+                                  metric_name: str,
+                                  model_column: str = 'model',
+                                  seed: int = 42) -> pd.DataFrame:
+    """Extract Bars with confidence intervals from seaborn plot for seaborn 0.13 and above.
+    Confident intervals are calculated with bootstrapping(sampling the mean).
+
+    Parameters
+    ----------
+    errors: pd.DataFrame
+        DataFrame created by `plot_errors_by_median` function
+    feat_name: str
+        feature name assigned(was transformed to 'intensity binned by median of {feat_name}')
+    metric_name: str
+        Metric used to calculate errors(MAE, MSE, etc) of intensities in bin
+    model_column: str
+        model_column in errors, defining model names
     """
     x_axis_name = f'intensity binned by median of {feat_name}'
+    aggregator = EstimateAggregator("mean", ("ci", 95), n_boot=1_000, seed=seed)
+    # ! need to iterate over all models myself using groupby
+    ret = (errors
+           .groupby(by=[x_axis_name, model_column,], observed=True)
+           [[x_axis_name, model_column, metric_name]]
+           .apply(lambda df: aggregator(df, metric_name))
+           .reset_index())
+    ret.columns = ["bin", model_column, "mean", "ci_low", "ci_high"]
+    return ret
 
-    plotter = _BarPlotter(data=errors, x=x_axis_name, y=metric_name, hue='model',
-                          order=None, hue_order=None,
-                          estimator="mean", errorbar=("ci", 95), n_boot=1000, units=None, seed=seed,
-                          orient=None, color=None, palette=None, saturation=.75, width=.8,
-                          errcolor=".26", errwidth=None, capsize=None, dodge=True)
-    ax = plt.gca()
-    plotter.plot(ax, {})
-    plt.close(ax.get_figure())
-    mean, cf_interval = plotter.statistic.flatten(), plotter.confint.reshape(-1, 2)
-    plotted = pd.DataFrame(np.concatenate((mean.reshape(-1, 1), cf_interval), axis=1), columns=[
-        'mean', 'ci_low', 'ci_high'])
-    _index = pd.DataFrame(list(itertools.product(
-        (_l.get_text() for _l in ax.get_xticklabels()),  # bins x-axis
-        (_l.get_text() for _l in ax.get_legend().get_texts()),  # models legend
-    )
-    ), columns=['bin', 'model'])
-    plotted = pd.concat([_index, plotted], axis=1)
-    return plotted
-
-
-# def get_data_for_errors_by_median_v2(errors: pd.DataFrame, feat_name, metric_name):
-#     from seaborn._statistics import (
-#         EstimateAggregator,
-#         WeightedAggregator,
-#     )
-#     from seaborn.categorical import _CategoricalAggPlotter, WeightedAggregator, EstimateAggregator
-#     p = _CategoricalAggPlotter(
-#         data=data,
-#         variables=dict(x=x, y=y, hue=hue, units=units, weight=weights),
-#         order=order,
-#         orient=orient,
-#         color=color,
-#         legend=legend,
-#     )
-
-#     agg_cls = WeightedAggregator if "weight" in p.plot_data else EstimateAggregator
-#     aggregator = agg_cls(estimator, errorbar, n_boot=n_boot, seed=seed)
-#     err_kws = {} if err_kws is None else normalize_kwargs(err_kws, mpl.lines.Line2D)
 
 
 def plot_rolling_error(errors: pd.DataFrame, metric_name: str, window: int = 200,
diff --git a/tests/plotting/test_errors.py b/tests/plotting/test_errors.py
index 9dc7ee75..b2f8f8c1 100644
--- a/tests/plotting/test_errors.py
+++ b/tests/plotting/test_errors.py
@@ -48,7 +48,8 @@ def expected_plotted():
     plotted_path = file_dir / 'expected_plotted.csv'
     # ! Windows reads in new line in string characters as '\r\n'
     df = pd.read_csv(plotted_path, sep=',', index_col=0)
-    df["bin"] = df["bin"].str.replace('\r\n', '\n')
+    df["bin"] = df["bin"].str.replace('\r\n', '\n').astype('category')
+    df = df.sort_values(by=['bin', 'model']).reset_index(drop=True)
     return df
 
 