import pickle
import pandas as pd
import numpy as np
- from sklearn.metrics import accuracy_score, explained_variance_score
+ from sklearn.metrics import explained_variance_score
+ from scipy.stats import wilcoxon
import seaborn as sns
import matplotlib.pyplot as plt

@@ -42,7 +43,6 @@ def plot_summary(summary, output_dir=None, filename="shap_plot", plot_top_n_shap
    plt.tight_layout()
    plt.show(block=False)
    plt.savefig(output_dir + f"summary_{filename}.png", dpi=100)
-     return


def shaps_to_summary(
@@ -73,7 +73,6 @@ def shaps_to_summary(
        filename=filename,
        plot_top_n_shap=plot_top_n_shap,
    )
-     return


def gen_report_shap_class(results, output_dir="./", plot_top_n_shap=16):
@@ -112,7 +111,7 @@ def gen_report_shap_class(results, output_dir="./", plot_top_n_shap=16):
        shaps_i = shaps[split_i]  # all shap values for this bootstrapping split
        y_true = y_true_and_preds[split_i][0]
        y_pred = y_true_and_preds[split_i][1]
-         #split_performance = accuracy_score(y_true, y_pred)
+         # split_performance = accuracy_score(y_true, y_pred)
        split_performance = explained_variance_score(y_true, y_pred)

        # split prediction indexes into TP, TN, FP, FN, good for error auditing
@@ -166,7 +165,7 @@ def gen_report_shap_class(results, output_dir="./", plot_top_n_shap=16):
        plot_top_n_shap=plot_top_n_shap,
    )
    save_obj(indexes_all, shap_dir + "indexes_quadrant.pkl")
-     return
+

def gen_report_shap_regres(results, output_dir="./", plot_top_n_shap=16):
    # Create shap_dir
@@ -197,7 +196,7 @@ def gen_report_shap_regres(results, output_dir="./", plot_top_n_shap=16):
197196 "lp" : [],
198197 "lm" : [],
199198 "um" : [],
200- "up" : []
199+ "up" : [],
201200 } # this is key with shape (F, N) where F is feature_names, N is mean shap values across splits
202201 # Obtain values for each bootstrapping split, then append summary statistics to shaps_n_splits
203202 for split_i in range (n_splits ):
@@ -208,8 +207,8 @@ def gen_report_shap_regres(results, output_dir="./", plot_top_n_shap=16):

        # split prediction indexes into upper, median, lower, good for error auditing
        indexes = {"lp": [], "lm": [], "um": [], "up": []}
-         q = np.array([25,50,75])
-         prc = np.percentile(y_true,q)
+         q = np.array([25, 50, 75])
+         prc = np.percentile(y_true, q)
        for i in range(len(y_true)):
            if prc[0] >= y_pred[i]:
                indexes["lp"].append(i)
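
The quartile bucketing above routes each prediction into one of four bins around the 25th/50th/75th percentiles of y_true; the hunk shows only the "lp" branch, so the remaining branches below are an assumption that mirrors the dict keys. A minimal sketch on toy values (illustrative, not from the commit):

    import numpy as np

    y_true = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0])
    y_pred = np.array([1.2, 2.6, 4.1, 4.4, 5.2, 6.3, 7.7, 8.5])
    prc = np.percentile(y_true, [25, 50, 75])  # quartile cut points

    indexes = {"lp": [], "lm": [], "um": [], "up": []}
    for i, pred in enumerate(y_pred):
        if prc[0] >= pred:    # at or below the lower quartile
            indexes["lp"].append(i)
        elif prc[1] >= pred:  # lower-middle band (assumed branch)
            indexes["lm"].append(i)
        elif prc[2] >= pred:  # upper-middle band (assumed branch)
            indexes["um"].append(i)
        else:                 # above the upper quartile (assumed branch)
            indexes["up"].append(i)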
@@ -259,7 +258,45 @@ def gen_report_shap_regres(results, output_dir="./", plot_top_n_shap=16):
        plot_top_n_shap=plot_top_n_shap,
    )
    save_obj(indexes_all, shap_dir + "indexes_quadrant.pkl")
-     return
+
+
+ def compute_pairwise_stats(df):
+     """Run Wilcoxon signed-rank tests across pairs of classifiers.
+
+     When comparing a classifier to itself, compare it to its null
+     distribution. A one-sided test is used throughout.
+
+     Assumes the dataframe has three keys: Classifier, type, and score,
+     where type marks a row as coming from either the data distribution
+     or the null distribution.
+     """
+     N = len(df.Classifier.unique())
+     # NxN matrices of test statistics and p-values, NaN until filled
+     effects = np.zeros((N, N)) * np.nan
+     pvalues = np.zeros((N, N)) * np.nan
+     for idx1, group1 in enumerate(df.groupby("Classifier")):
+         # scores obtained on the real data for this classifier
+         mask = group1[1].apply(lambda x: x.type == "data", axis=1).values
+         group1df = group1[1].iloc[mask, :]
+         # scores obtained on the permuted (null) data
+         mask = group1[1].apply(lambda x: x.type == "null", axis=1).values
+         group1nulldf = group1[1].iloc[mask, :]
+         for idx2, group2 in enumerate(df.groupby("Classifier")):
+             mask = group2[1].apply(lambda x: x.type == "data", axis=1).values
+             group2df = group2[1].iloc[mask, :]
+             if group1[0] != group2[0]:
+                 # H1: classifier 1 scores higher than classifier 2
+                 stat, pval = wilcoxon(
+                     group1df["score"].values,
+                     group2df["score"].values,
+                     alternative="greater",
+                 )
+             else:
+                 # H1: classifier scores higher than its own null distribution
+                 stat, pval = wilcoxon(
+                     group1df["score"].values,
+                     group1nulldf["score"].values,
+                     alternative="greater",
+                 )
+             effects[idx1, idx2] = stat
+             pvalues[idx1, idx2] = pval
+     return effects, pvalues
+

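A minimal usage sketch for compute_pairwise_stats, assuming the long-format frame its docstring describes: one row per bootstrap score, with matched "data" and "null" rows per classifier. The classifier names and scores below are illustrative only:

    import pandas as pd

    df = pd.DataFrame(
        {
            "Classifier": ["LogReg"] * 6 + ["SVC"] * 6,
            "type": (["data"] * 3 + ["null"] * 3) * 2,
            "score": [0.70, 0.72, 0.69, 0.50, 0.49, 0.51,
                      0.80, 0.82, 0.79, 0.50, 0.52, 0.48],
        }
    )
    effects, pvalues = compute_pairwise_stats(df)
    # effects[i, j] tests whether classifier i scores higher than classifier j;
    # the diagonal compares each classifier against its own null distribution.
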
def gen_report(
    results, prefix, metrics, gen_shap=True, output_dir="./", plot_top_n_shap=16
@@ -284,6 +321,7 @@ def gen_report(
            },
            ignore_index=True,
        )
+     order = [group[0] for group in df.groupby("Classifier")]
    for name, subdf in df.groupby("metric"):
        sns.set(style="whitegrid", palette="pastel", color_codes=True)
        sns.set_context("talk")
@@ -296,7 +334,7 @@ def gen_report(
            split=True,
            inner="quartile",
            hue_order=["data", "null"],
-             order=[group[0] for group in df.groupby("Classifier")],
+             order=order,
        )
        ax.set_ylabel(name)
        sns.despine(left=True)
@@ -306,16 +344,39 @@ def gen_report(
        timestamp = timestamp.replace(":", "").replace("-", "")
        plt.savefig(f"test-{name}-{timestamp}.png")

+         # Create comparison stats table if the metric is a score
+         if "score" in name:
+             effects, pvalues = compute_pairwise_stats(subdf)
+             plt.figure(figsize=(8, 8))
+             ax = sns.heatmap(
+                 effects,
+                 # annotate cells with whole orders of magnitude of significance
+                 annot=np.fix(-np.log10(pvalues)),
+                 yticklabels=order,
+                 xticklabels=order,
+                 cbar=True,
+                 square=True,
+             )
+             ax.xaxis.set_ticks_position("top")
+             plt.savefig(f"stats-{name}-{timestamp}.png")
+
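For reading those annotations: np.fix truncates toward zero, so each cell shows the whole orders of magnitude of the p-value, i.e. truncated -log10(p). A quick check (this reading is an inference, not stated in the commit):

    import numpy as np

    p = np.array([0.5, 0.05, 0.003, 1e-6])
    print(np.fix(-np.log10(p)))  # [0. 1. 2. 6.]
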
    # create SHAP summary csv and figures
    if gen_shap:
-         reg_metrics = ["explained_variance_score","max_error",
-                        "mean_absolute_error","mean_squared_error",
-                        "mean_squared_log_error","median_absolute_error",
-                        "r2_score","mean_poisson_deviance",
-                        "mean_gamma_deviance"
-                        ]
+         reg_metrics = [
+             "explained_variance_score",
+             "max_error",
+             "mean_absolute_error",
+             "mean_squared_error",
+             "mean_squared_log_error",
+             "median_absolute_error",
+             "r2_score",
+             "mean_poisson_deviance",
+             "mean_gamma_deviance",
+         ]
        if any([True for x in metrics if x in reg_metrics]):
-             gen_report_shap_regres(results, output_dir=output_dir, plot_top_n_shap=plot_top_n_shap)
+             gen_report_shap_regres(
+                 results, output_dir=output_dir, plot_top_n_shap=plot_top_n_shap
+             )
        else:
-             gen_report_shap_class(results, output_dir=output_dir, plot_top_n_shap=plot_top_n_shap)
-
+             gen_report_shap_class(
+                 results, output_dir=output_dir, plot_top_n_shap=plot_top_n_shap
+             )