Skip to content

Commit 4a3ebcd

Browse files
authored
Merge pull request #4 from elombardi2/master
Fix bar plot layout, add a new figure
2 parents d76acb9 + 79c1a8c commit 4a3ebcd

File tree

1 file changed

+85
-11
lines changed

1 file changed

+85
-11
lines changed

plot.py

+85-11
Original file line numberDiff line numberDiff line change
@@ -8,25 +8,34 @@
88
import matplotlib.pyplot as plt
99
from matplotlib.ticker import MaxNLocator
1010
from collections import namedtuple
11+
import sys
1112

1213

13-
def arr_train():
14+
def all_device_names():
1415
result_path=os.path.join(os.getcwd(),'results')
15-
arr_train=[i for i in glob.glob(result_path+'/*training*.csv')]
16+
names=[i for i in glob.glob(result_path+'/*training*.csv')]
17+
names=[n.split('/')[-1] for n in names]
18+
names=[n.split('_')[0] for n in names]
19+
names=set(names) # remove duplicated names
20+
return names
21+
22+
def arr_train(device_name):
23+
result_path=os.path.join(os.getcwd(),'results')
24+
arr_train=[i for i in glob.glob(result_path+'/' + device_name + '*training*.csv')]
1625
arr_train.sort()
1726
return arr_train
1827

19-
def arr_inference():
28+
def arr_inference(device_name):
2029
result_path=os.path.join(os.getcwd(),'results')
21-
arr_inference=[i for i in glob.glob(result_path+'/*inference*.csv')]
30+
arr_inference=[i for i in glob.glob(result_path+'/' + device_name + '*inference*.csv')]
2231
arr_inference.sort()
2332
return arr_inference
2433

2534

2635
def total_model(arr,device_name):
2736

2837
model_name=arr[0].split('/')[-1].split('_')[0]
29-
type=arr[0].split('/')[-1].split('_')[3]
38+
type=arr[0].split('/')[-1].split('_')[-2]
3039
n_groups = 15
3140

3241
double=pd.read_csv(arr[0])
@@ -45,9 +54,9 @@ def total_model(arr,device_name):
4554
fig, ax = plt.subplots()
4655

4756
index = np.arange(n_groups)
48-
bar_width = 0.35
57+
bar_width = 0.25
4958

50-
opacity = 0.4
59+
opacity = 0.6
5160
error_config = {'ecolor': '0.3'}
5261

5362
rects1 = ax.bar(index, means_double, bar_width,
@@ -60,21 +69,21 @@ def total_model(arr,device_name):
6069
yerr=std_half, error_kw=error_config,
6170
label='half')
6271

63-
rects2 = ax.bar(index + bar_width*2, means_single, bar_width,
72+
rects3 = ax.bar(index + bar_width*2, means_single, bar_width,
6473
alpha=opacity, color='g',
6574
yerr=std_single, error_kw=error_config,
6675
label='single')
6776

6877
ax.set_xlabel('models')
6978
ax.set_ylabel('times(ms)')
7079
ax.set_title("total_"+type+"_"+model_name)
71-
ax.set_xticks(index + bar_width / 2)
72-
ax.set_xticklabels(double.columns,rotation=60, fontsize=9)
80+
ax.set_xticks(index + bar_width)
81+
ax.set_xticklabels(double.columns,rotation=90, fontsize=9)
7382

7483
ax.legend()
7584

7685
fig.tight_layout()
77-
plt.savefig(device_name+'total.png',dpi=400)
86+
plt.savefig(device_name+'_'+type+'_total.png',dpi=400)
7887

7988

8089

@@ -233,3 +242,68 @@ def model_plot2(arr,model):
233242
ax.legend()
234243
fig.tight_layout()
235244
plt.savefig(model+'_'+type+'_'+model_name+'.png',dpi=300)
245+
246+
247+
# plot all given files on a single figure
248+
def plot_on_a_single_figure(filenames):
249+
xlabels = []
250+
251+
# set plot size
252+
plt.figure(figsize=(10,7.5))
253+
254+
for filename in filenames:
255+
# retrieve infos from file name
256+
basename_splitted = filename.split('/')[-1].split('_')
257+
model_name = basename_splitted[0]
258+
type = basename_splitted[-2] # training/inference
259+
precision = basename_splitted[-4] # half/single/double
260+
261+
# load file and sort columns
262+
data = pd.read_csv(filename)
263+
data = data.sort_index(axis=1)
264+
265+
# ensure all data have same x labels
266+
if xlabels:
267+
# compare x labels order
268+
assert (data.columns.tolist() == xlabels)
269+
else:
270+
# init x labels (1st file)
271+
xlabels = data.columns.tolist()
272+
273+
# compute mean values
274+
mean_values = data.mean().values
275+
276+
# plot data
277+
plt.plot(mean_values, label = model_name + '_' + type + '_' + precision)
278+
279+
# set x labels
280+
plt.xticks(range(len(xlabels)), xlabels, rotation='vertical')
281+
282+
# set axis labels, title, legend
283+
plt.xlabel('models')
284+
plt.ylabel('times (ms)')
285+
#plt.title("Simple Plot")
286+
plt.legend()
287+
plt.tight_layout()
288+
plt.show()
289+
290+
291+
if __name__ == '__main__':
292+
#import ipdb ; ipdb.set_trace()
293+
files = sys.argv[1:]
294+
if files:
295+
# plot all given files on the same figure
296+
plot_on_a_single_figure(files)
297+
298+
else:
299+
# plot all files in 'results' on separate figures
300+
device_names = all_device_names()
301+
for device_name in device_names:
302+
print('device_name =', device_name)
303+
304+
train=arr_train(device_name)
305+
inference=arr_inference(device_name)
306+
307+
total_model(train,device_name)
308+
total_model(inference,device_name)
309+

0 commit comments

Comments
 (0)