
Commit 1c1f484

enh: add generation of model comparison statistic
1 parent 8a4ea5e commit 1c1f484

File tree

1 file changed: +80, -19 lines


pydra_ml/report.py

Lines changed: 80 additions & 19 deletions
@@ -4,7 +4,8 @@
 import pickle
 import pandas as pd
 import numpy as np
-from sklearn.metrics import accuracy_score, explained_variance_score
+from sklearn.metrics import explained_variance_score
+from scipy.stats import wilcoxon
 import seaborn as sns
 import matplotlib.pyplot as plt
 

@@ -42,7 +43,6 @@ def plot_summary(summary, output_dir=None, filename="shap_plot", plot_top_n_shap
     plt.tight_layout()
     plt.show(block=False)
     plt.savefig(output_dir + f"summary_{filename}.png", dpi=100)
-    return
 
 
 def shaps_to_summary(
@@ -73,7 +73,6 @@ def shaps_to_summary(
         filename=filename,
         plot_top_n_shap=plot_top_n_shap,
     )
-    return
 
 
 def gen_report_shap_class(results, output_dir="./", plot_top_n_shap=16):
@@ -112,7 +111,7 @@ def gen_report_shap_class(results, output_dir="./", plot_top_n_shap=16):
         shaps_i = shaps[split_i]  # all shap values for this bootstrapping split
         y_true = y_true_and_preds[split_i][0]
         y_pred = y_true_and_preds[split_i][1]
-        #split_performance = accuracy_score(y_true, y_pred)
+        # split_performance = accuracy_score(y_true, y_pred)
         split_performance = explained_variance_score(y_true, y_pred)
 
         # split prediction indexes into TP, TN, FP, FN, good for error auditing
@@ -166,7 +165,7 @@ def gen_report_shap_class(results, output_dir="./", plot_top_n_shap=16):
         plot_top_n_shap=plot_top_n_shap,
     )
     save_obj(indexes_all, shap_dir + "indexes_quadrant.pkl")
-    return
+
 
 def gen_report_shap_regres(results, output_dir="./", plot_top_n_shap=16):
     # Create shap_dir
@@ -197,7 +196,7 @@ def gen_report_shap_regres(results, output_dir="./", plot_top_n_shap=16):
         "lp": [],
         "lm": [],
         "um": [],
-        "up": []
+        "up": [],
     }  # this is key with shape (F, N) where F is feature_names, N is mean shap values across splits
     # Obtain values for each bootstrapping split, then append summary statistics to shaps_n_splits
     for split_i in range(n_splits):
@@ -208,8 +207,8 @@ def gen_report_shap_regres(results, output_dir="./", plot_top_n_shap=16):
 
         # split prediction indexes into upper, median, lower, good for error auditing
         indexes = {"lp": [], "lm": [], "um": [], "up": []}
-        q=np.array([25,50,75])
-        prc=np.percentile(y_true,q)
+        q = np.array([25, 50, 75])
+        prc = np.percentile(y_true, q)
         for i in range(len(y_true)):
             if prc[0] >= y_pred[i]:
                 indexes["lp"].append(i)
@@ -259,7 +258,45 @@ def gen_report_shap_regres(results, output_dir="./", plot_top_n_shap=16):
         plot_top_n_shap=plot_top_n_shap,
     )
     save_obj(indexes_all, shap_dir + "indexes_quadrant.pkl")
-    return
+
+
+def compute_pairwise_stats(df):
+    """Run Wilcoxon signed-rank tests across pairs of classifiers.
+
+    When comparing a classifier to itself, compare it to its null distribution.
+    A one-sided test is used.
+
+    Assumes that the dataframe has three keys: Classifier, type, and score,
+    with type referring to either the data distribution or the null distribution.
+
+    """
+    N = len(df.Classifier.unique())
+    effects = np.zeros((N, N)) * np.nan
+    pvalues = np.zeros((N, N)) * np.nan
+    for idx1, group1 in enumerate(df.groupby("Classifier")):
+        filter = group1[1].apply(lambda x: x.type == "data", axis=1).values
+        group1df = group1[1].iloc[filter, :]
+        filter = group1[1].apply(lambda x: x.type == "null", axis=1).values
+        group1nulldf = group1[1].iloc[filter, :]
+        for idx2, group2 in enumerate(df.groupby("Classifier")):
+            filter = group2[1].apply(lambda x: x.type == "data", axis=1).values
+            group2df = group2[1].iloc[filter, :]
+            if group1[0] != group2[0]:
+                stat, pval = wilcoxon(
+                    group1df["score"].values,
+                    group2df["score"].values,
+                    alternative="greater",
+                )
+            else:
+                stat, pval = wilcoxon(
+                    group1df["score"].values,
+                    group1nulldf["score"].values,
+                    alternative="greater",
+                )
+            effects[idx1, idx2] = stat
+            pvalues[idx1, idx2] = pval
+    return effects, pvalues
+
 
 def gen_report(
     results, prefix, metrics, gen_shap=True, output_dir="./", plot_top_n_shap=16
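
For orientation, here is a minimal usage sketch of the helper added above (not part of the commit). It assumes compute_pairwise_stats is importable from pydra_ml.report and fabricates a long-format dataframe with the three columns the docstring requires; the classifier names, lifts, and scores are invented for illustration.

import numpy as np
import pandas as pd

from pydra_ml.report import compute_pairwise_stats  # assumed import path

# Invented bootstrap scores: two classifiers, each with 50 paired scores
# from the real targets ("data") and 50 from permuted targets ("null").
rng = np.random.default_rng(0)
rows = []
for clf, lift in [("MLPClassifier", 0.15), ("SGDClassifier", 0.05)]:
    for dist, shift in [("data", lift), ("null", 0.0)]:
        for score in rng.normal(0.5 + shift, 0.05, size=50):
            rows.append({"Classifier": clf, "type": dist, "score": score})
df = pd.DataFrame(rows)

# effects[i, j] holds the one-sided Wilcoxon statistic for classifier i vs j;
# the diagonal compares each classifier's data scores to its own null scores.
effects, pvalues = compute_pairwise_stats(df)
print(pvalues)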
@@ -284,6 +321,7 @@ def gen_report(
             },
             ignore_index=True,
         )
+    order = [group[0] for group in df.groupby("Classifier")]
     for name, subdf in df.groupby("metric"):
         sns.set(style="whitegrid", palette="pastel", color_codes=True)
         sns.set_context("talk")
@@ -296,7 +334,7 @@ def gen_report(
             split=True,
             inner="quartile",
             hue_order=["data", "null"],
-            order=[group[0] for group in df.groupby("Classifier")],
+            order=order,
         )
         ax.set_ylabel(name)
         sns.despine(left=True)
@@ -306,16 +344,39 @@ def gen_report(
         timestamp = timestamp.replace(":", "").replace("-", "")
         plt.savefig(f"test-{name}-{timestamp}.png")
 
+        # Create comparison stats table if the metric is a score
+        if "score" in name:
+            effects, pvalues = compute_pairwise_stats(subdf)
+            plt.figure(figsize=(8, 8))
+            ax = sns.heatmap(
+                effects,
+                annot=np.fix(-np.log10(pvalues)),
+                yticklabels=order,
+                xticklabels=order,
+                cbar=True,
+                square=True,
+            )
+            ax.xaxis.set_ticks_position("top")
+            plt.savefig(f"stats-{name}-{timestamp}.png")
+
     # create SHAP summary csv and figures
     if gen_shap:
-        reg_metrics=["explained_variance_score","max_error",
-                     "mean_absolute_error","mean_squared_error",
-                     "mean_squared_log_error","median_absolute_error",
-                     "r2_score","mean_poisson_deviance",
-                     "mean_gamma_deviance"
-        ]
+        reg_metrics = [
+            "explained_variance_score",
+            "max_error",
+            "mean_absolute_error",
+            "mean_squared_error",
+            "mean_squared_log_error",
+            "median_absolute_error",
+            "r2_score",
+            "mean_poisson_deviance",
+            "mean_gamma_deviance",
+        ]
         if any([True for x in metrics if x in reg_metrics]):
-            gen_report_shap_regres(results, output_dir=output_dir, plot_top_n_shap=plot_top_n_shap)
+            gen_report_shap_regres(
+                results, output_dir=output_dir, plot_top_n_shap=plot_top_n_shap
+            )
         else:
-            gen_report_shap_class(results, output_dir=output_dir, plot_top_n_shap=plot_top_n_shap)
-
+            gen_report_shap_class(
+                results, output_dir=output_dir, plot_top_n_shap=plot_top_n_shap
+            )
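
A note on the heatmap annotation in the last hunk: np.fix truncates toward zero, so each cell is annotated with the integer part of -log10(p), i.e. roughly how many orders of magnitude the p-value lies below 1. A quick illustration with made-up p-values:

import numpy as np

p = np.array([0.5, 0.04, 0.004])
print(-np.log10(p))          # approximately [0.30, 1.40, 2.40]
print(np.fix(-np.log10(p)))  # [0., 1., 2.]  -> p < 1, p < 0.1, p < 0.01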
