@@ -336,28 +336,44 @@ def gen_report(
336336 hue_order = ["data" , "null" ],
337337 order = order ,
338338 )
339+ ax .xaxis .set_ticks_position ("top" )
340+ ax .set_xticklabels (ax .get_xticklabels (), rotation = 90 , ha = "center" )
339341 ax .set_ylabel (name )
342+ ax .legend (loc = "center right" , bbox_to_anchor = (1.2 , 0.5 ), ncol = 1 )
343+ ax .tick_params (axis = "both" , which = "both" , length = 0 )
340344 sns .despine (left = True )
345+ plt .tight_layout ()
346+
341347 import datetime
342348
343349 timestamp = datetime .datetime .utcnow ().isoformat ()
344350 timestamp = timestamp .replace (":" , "" ).replace ("-" , "" )
345351 plt .savefig (f"test-{ name } -{ timestamp } .png" )
352+ plt .close ()
346353
347354 # Create comparison stats table if the metric is a score
348355 if "score" in name :
349356 effects , pvalues , = compute_pairwise_stats (subdf )
350- plt .figure (figsize = (8 , 8 ))
357+ sns .set (style = "whitegrid" , palette = "pastel" , color_codes = True )
358+ sns .set_context ("talk" )
359+ plt .figure (figsize = (2 * len (order ), 2 * len (order )))
360+ # plt.figure(figsize=(8, 8))
351361 ax = sns .heatmap (
352362 effects ,
353363 annot = np .fix (- np .log10 (pvalues )),
354364 yticklabels = order ,
355365 xticklabels = order ,
356366 cbar = True ,
367+ cbar_kws = {"shrink" : 0.7 },
357368 square = True ,
358369 )
359370 ax .xaxis .set_ticks_position ("top" )
371+ ax .set_xticklabels (ax .get_xticklabels (), rotation = 90 , ha = "center" )
372+ ax .set_yticklabels (ax .get_yticklabels (), rotation = 0 , ha = "right" )
373+ ax .tick_params (axis = "both" , which = "both" , length = 0 )
374+ plt .tight_layout ()
360375 plt .savefig (f"stats-{ name } -{ timestamp } .png" )
376+ plt .close ()
361377 save_obj (
362378 dict (effects = effects , pvalues = pvalues , order = order ),
363379 f"stats-{ name } -{ timestamp } .pkl" ,
0 commit comments