8
8
import matplotlib .pyplot as plt
9
9
from matplotlib .ticker import MaxNLocator
10
10
from collections import namedtuple
11
+ import sys
11
12
12
13
13
- def arr_train ():
14
+ def all_device_names ():
14
15
result_path = os .path .join (os .getcwd (),'results' )
15
- arr_train = [i for i in glob .glob (result_path + '/*training*.csv' )]
16
+ names = [i for i in glob .glob (result_path + '/*training*.csv' )]
17
+ names = [n .split ('/' )[- 1 ] for n in names ]
18
+ names = [n .split ('_' )[0 ] for n in names ]
19
+ names = set (names ) # remove duplicated names
20
+ return names
21
+
22
+ def arr_train (device_name ):
23
+ result_path = os .path .join (os .getcwd (),'results' )
24
+ arr_train = [i for i in glob .glob (result_path + '/' + device_name + '*training*.csv' )]
16
25
arr_train .sort ()
17
26
return arr_train
18
27
19
- def arr_inference ():
28
+ def arr_inference (device_name ):
20
29
result_path = os .path .join (os .getcwd (),'results' )
21
- arr_inference = [i for i in glob .glob (result_path + '/*inference*.csv' )]
30
+ arr_inference = [i for i in glob .glob (result_path + '/' + device_name + ' *inference*.csv' )]
22
31
arr_inference .sort ()
23
32
return arr_inference
24
33
25
34
26
35
def total_model (arr ,device_name ):
27
36
28
37
model_name = arr [0 ].split ('/' )[- 1 ].split ('_' )[0 ]
29
- type = arr [0 ].split ('/' )[- 1 ].split ('_' )[3 ]
38
+ type = arr [0 ].split ('/' )[- 1 ].split ('_' )[- 2 ]
30
39
n_groups = 15
31
40
32
41
double = pd .read_csv (arr [0 ])
@@ -45,9 +54,9 @@ def total_model(arr,device_name):
45
54
fig , ax = plt .subplots ()
46
55
47
56
index = np .arange (n_groups )
48
- bar_width = 0.35
57
+ bar_width = 0.25
49
58
50
- opacity = 0.4
59
+ opacity = 0.6
51
60
error_config = {'ecolor' : '0.3' }
52
61
53
62
rects1 = ax .bar (index , means_double , bar_width ,
@@ -60,21 +69,21 @@ def total_model(arr,device_name):
60
69
yerr = std_half , error_kw = error_config ,
61
70
label = 'half' )
62
71
63
- rects2 = ax .bar (index + bar_width * 2 , means_single , bar_width ,
72
+ rects3 = ax .bar (index + bar_width * 2 , means_single , bar_width ,
64
73
alpha = opacity , color = 'g' ,
65
74
yerr = std_single , error_kw = error_config ,
66
75
label = 'single' )
67
76
68
77
ax .set_xlabel ('models' )
69
78
ax .set_ylabel ('times(ms)' )
70
79
ax .set_title ("total_" + type + "_" + model_name )
71
- ax .set_xticks (index + bar_width / 2 )
72
- ax .set_xticklabels (double .columns ,rotation = 60 , fontsize = 9 )
80
+ ax .set_xticks (index + bar_width )
81
+ ax .set_xticklabels (double .columns ,rotation = 90 , fontsize = 9 )
73
82
74
83
ax .legend ()
75
84
76
85
fig .tight_layout ()
77
- plt .savefig (device_name + 'total .png' ,dpi = 400 )
86
+ plt .savefig (device_name + '_' + type + '_total .png' ,dpi = 400 )
78
87
79
88
80
89
@@ -233,3 +242,68 @@ def model_plot2(arr,model):
233
242
ax .legend ()
234
243
fig .tight_layout ()
235
244
plt .savefig (model + '_' + type + '_' + model_name + '.png' ,dpi = 300 )
245
+
246
+
247
+ # plot all given files on a single figure
248
+ def plot_on_a_single_figure (filenames ):
249
+ xlabels = []
250
+
251
+ # set plot size
252
+ plt .figure (figsize = (10 ,7.5 ))
253
+
254
+ for filename in filenames :
255
+ # retrieve infos from file name
256
+ basename_splitted = filename .split ('/' )[- 1 ].split ('_' )
257
+ model_name = basename_splitted [0 ]
258
+ type = basename_splitted [- 2 ] # training/inference
259
+ precision = basename_splitted [- 4 ] # half/single/double
260
+
261
+ # load file and sort columns
262
+ data = pd .read_csv (filename )
263
+ data = data .sort_index (axis = 1 )
264
+
265
+ # ensure all data have same x labels
266
+ if xlabels :
267
+ # compare x labels order
268
+ assert (data .columns .tolist () == xlabels )
269
+ else :
270
+ # init x labels (1st file)
271
+ xlabels = data .columns .tolist ()
272
+
273
+ # compute mean values
274
+ mean_values = data .mean ().values
275
+
276
+ # plot data
277
+ plt .plot (mean_values , label = model_name + '_' + type + '_' + precision )
278
+
279
+ # set x labels
280
+ plt .xticks (range (len (xlabels )), xlabels , rotation = 'vertical' )
281
+
282
+ # set axis labels, title, legend
283
+ plt .xlabel ('models' )
284
+ plt .ylabel ('times (ms)' )
285
+ #plt.title("Simple Plot")
286
+ plt .legend ()
287
+ plt .tight_layout ()
288
+ plt .show ()
289
+
290
+
291
+ if __name__ == '__main__' :
292
+ #import ipdb ; ipdb.set_trace()
293
+ files = sys .argv [1 :]
294
+ if files :
295
+ # plot all given files on the same figure
296
+ plot_on_a_single_figure (files )
297
+
298
+ else :
299
+ # plot all files in 'results' on separate figures
300
+ device_names = all_device_names ()
301
+ for device_name in device_names :
302
+ print ('device_name =' , device_name )
303
+
304
+ train = arr_train (device_name )
305
+ inference = arr_inference (device_name )
306
+
307
+ total_model (train ,device_name )
308
+ total_model (inference ,device_name )
309
+
0 commit comments