Skip to content
This repository was archived by the owner on Dec 19, 2024. It is now read-only.

Commit 60a66fc

Browse files
Add box plots and unittests (#137)
* Add box plots and unittests * improve unit tests * fix positional args * fix positional args * change default title to none Co-authored-by: Bowen Li <[email protected]>
1 parent 64c511f commit 60a66fc

File tree

3 files changed

+175
-0
lines changed

3 files changed

+175
-0
lines changed

datasetinsights/stats/__init__.py

+4
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
bar_plot,
44
grid_plot,
55
histogram_plot,
6+
model_performance_box_plot,
7+
model_performance_comparison_box_plot,
68
plot_bboxes,
79
rotation_plot,
810
)
@@ -12,6 +14,8 @@
1214
"grid_plot",
1315
"histogram_plot",
1416
"plot_bboxes",
17+
"model_performance_box_plot",
18+
"model_performance_comparison_box_plot",
1519
"rotation_plot",
1620
"RenderedObjectInfo",
1721
]

datasetinsights/stats/visualization/plots.py

+126
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@
2222
COLORS = list(ImageColor.colormap.values())
2323
FONT_SCALE = 35
2424
LINE_WIDTH_SCALE = 250
25+
ERROR_BAR_BASE_COLOR = "indianred"
26+
ERROR_BAR_COMPARE_COLOR = "lightseagreen"
2527

2628

2729
def decode_segmap(labels, dataset="cityscapes"):
@@ -370,3 +372,127 @@ def rotation_plot(df, x, y, z=None, max_samples=None, title=None, **kwargs):
370372
)
371373
)
372374
return fig
375+
376+
377+
def model_performance_comparison_box_plot(
378+
title=None,
379+
mean_ap_base=None,
380+
mean_ap_50_base=None,
381+
mean_ar_base=None,
382+
mean_ap_new=None,
383+
mean_ap_50_new=None,
384+
mean_ar_new=None,
385+
range=[0, 1.0],
386+
**kwargs,
387+
):
388+
"""Create a box plot for a base and new model performance
389+
Args:
390+
title (str): title of the plot
391+
mean_ap_base (list): a list of base mAP
392+
mean_ap_50_base (list): a list of base mAP
393+
mean_ar_base (list): a list of base mAP
394+
mean_ap_new (list): a list of base mAP
395+
mean_ap_50_new (list): a list of base mAP
396+
mean_ar_new (list): a list of base mAP
397+
range (list): the range of y axis. Defaults to [0, 1.0]
398+
399+
Returns:
400+
A plotly.graph_objects.Figure containing the box plot
401+
"""
402+
fig = go.Figure(
403+
layout=go.Layout(title=go.layout.Title(text=title), **kwargs)
404+
)
405+
fig.update_yaxes(range=range)
406+
_fig_add_trace(
407+
fig,
408+
mean_ap_base,
409+
name="baes mAP",
410+
base=True,
411+
color=ERROR_BAR_BASE_COLOR,
412+
)
413+
_fig_add_trace(
414+
fig,
415+
mean_ap_new,
416+
name="new mAP",
417+
base=False,
418+
color=ERROR_BAR_COMPARE_COLOR,
419+
)
420+
_fig_add_trace(
421+
fig,
422+
mean_ap_50_base,
423+
name="base mAP50",
424+
base=True,
425+
color=ERROR_BAR_BASE_COLOR,
426+
)
427+
_fig_add_trace(
428+
fig,
429+
mean_ap_50_new,
430+
name="new mAP50",
431+
base=False,
432+
color=ERROR_BAR_COMPARE_COLOR,
433+
)
434+
_fig_add_trace(
435+
fig,
436+
mean_ar_base,
437+
name="base mAR",
438+
base=True,
439+
color=ERROR_BAR_BASE_COLOR,
440+
)
441+
_fig_add_trace(
442+
fig,
443+
mean_ar_new,
444+
name="new mAR",
445+
base=False,
446+
color=ERROR_BAR_COMPARE_COLOR,
447+
)
448+
449+
return fig
450+
451+
452+
def model_performance_box_plot(
453+
title=None,
454+
mean_ap=None,
455+
mean_ap_50=None,
456+
mean_ar=None,
457+
range=[0, 1.0],
458+
**kwargs,
459+
):
460+
"""Create a box plot for one model performance
461+
Args:
462+
title (str): title of the plot
463+
mean_ap (list): a list of base mAP
464+
mean_ap_50 (list): a list of base mAP
465+
mean_ar (list): a list of base mAP
466+
range (list): the range of y axis. Defaults to [0, 1.0]
467+
468+
Returns:
469+
A plotly.graph_objects.Figure containing the box plot
470+
"""
471+
fig = go.Figure(
472+
layout=go.Layout(title=go.layout.Title(text=title), **kwargs)
473+
)
474+
fig.update_yaxes(range=range)
475+
_fig_add_trace(fig, mean_ap, name="mAP")
476+
_fig_add_trace(fig, mean_ap_50, name="mAP@IOU50")
477+
_fig_add_trace(fig, mean_ar, name="mAR")
478+
479+
return fig
480+
481+
482+
def _fig_add_trace(
483+
fig=None, data=[], name="", base=True, color=ERROR_BAR_BASE_COLOR
484+
):
485+
"""Add box plot in figure
486+
Args:
487+
fig (go.Figure): figure you want to add box plot in
488+
data (list): metric values
489+
name (str): name of the added box plot
490+
base (bool): whether is a base metric. Defaults to true
491+
color (str): color for box plot
492+
493+
Returns:
494+
A plotly.graph_objects.Figure containing the box plot
495+
"""
496+
if not data:
497+
return
498+
fig.add_trace(go.Box(y=data, name=name, marker_color=color))

tests/test_visual.py

+45
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@
2121
decode_segmap,
2222
histogram_plot,
2323
match_boxes,
24+
model_performance_box_plot,
25+
model_performance_comparison_box_plot,
2426
plot_bboxes,
2527
)
2628

@@ -33,6 +35,14 @@ def get_image_and_bbox():
3335
return image, bbox
3436

3537

38+
@pytest.fixture
39+
def get_evaluation_metrics():
40+
mean_ap = [0.1, 0.2, 0.3]
41+
mean_ap_50 = [0.3, 0.4, 0.5]
42+
mean_ar = [0.2, 0.3, 0.4]
43+
return [mean_ap, mean_ap_50, mean_ar]
44+
45+
3646
def test_decode_segmap():
3747
ids = list(CITYSCAPES_COLOR_MAPPING.keys())
3848
colors = list(CITYSCAPES_COLOR_MAPPING.values())
@@ -57,6 +67,41 @@ def test_histogram_plot():
5767
assert fig == mock_layout
5868

5969

70+
@patch("datasetinsights.stats.visualization.plots.go.Figure.add_trace")
71+
@patch("datasetinsights.stats.visualization.plots.go.Figure.update_yaxes")
72+
def test_model_performance_box_plot(
73+
mock_update, mock_add_trace, get_evaluation_metrics
74+
):
75+
mean_ap, mean_ap_50, mean_ar = get_evaluation_metrics
76+
title = "test plot"
77+
model_performance_box_plot(
78+
title=title, mean_ap=mean_ap, mean_ap_50=mean_ap_50, mean_ar=mean_ar
79+
)
80+
assert mock_add_trace.call_count == 3
81+
assert mock_update.call_count == 1
82+
83+
84+
@patch("datasetinsights.stats.visualization.plots.go.Figure.add_trace")
85+
@patch("datasetinsights.stats.visualization.plots.go.Figure.update_yaxes")
86+
def test_model_performance_comparison_box_plot(
87+
mock_update, mock_add_trace, get_evaluation_metrics
88+
):
89+
mean_ap_base, mean_ap_50_base, mean_ar_base = get_evaluation_metrics
90+
mean_ap_new, mean_ap_50_new, mean_ar_new = get_evaluation_metrics
91+
title = "test plot"
92+
model_performance_comparison_box_plot(
93+
title=title,
94+
mean_ap_base=mean_ap_base,
95+
mean_ap_50_base=mean_ap_50_base,
96+
mean_ar_base=mean_ar_base,
97+
mean_ap_new=mean_ap_new,
98+
mean_ap_50_new=mean_ap_50_new,
99+
mean_ar_new=mean_ar_new,
100+
)
101+
assert mock_add_trace.call_count == 6
102+
mock_update.assert_called_once()
103+
104+
60105
def test_bar_plot():
61106
df = pd.DataFrame({"x": ["a", "b", "c"], "y": [1, 2, 3]})
62107
mock_figure = Mock()

0 commit comments

Comments
 (0)