forked from mlcommons/inference
-
Notifications
You must be signed in to change notification settings - Fork 0
/
accuracy-imagenet.py
74 lines (57 loc) · 2.15 KB
/
accuracy-imagenet.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
"""
Tool to calculate accuracy for loadgen accuracy output found in mlperf_log_accuracy.json
We assume that loadgen's query index is in the same order as the images in imagenet2012/val_map.txt.
"""
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
import argparse
import json
import numpy as np
# pylint: disable=missing-docstring
def get_args(argv=None):
    """Parse commandline.

    Args:
        argv: optional list of argument strings to parse instead of the real
            command line. ``None`` (the default) makes argparse read
            ``sys.argv[1:]``, which is the original behaviour.

    Returns:
        argparse.Namespace with attributes ``mlperf_accuracy_file``,
        ``imagenet_val_file``, ``verbose`` (bool) and ``dtype`` (one of
        ``"float32"``, ``"int32"``, ``"int64"``).
    """
    parser = argparse.ArgumentParser(
        description="Compute accuracy of a loadgen mlperf_log_accuracy.json "
                    "against the labels in imagenet val_map.txt.")
    parser.add_argument("--mlperf-accuracy-file", required=True, help="path to mlperf_log_accuracy.json")
    parser.add_argument("--imagenet-val-file", required=True, help="path to imagenet val_map.txt")
    parser.add_argument("--verbose", action="store_true", help="verbose messages")
    parser.add_argument("--dtype", default="float32", choices=["float32", "int32", "int64"], help="data type of the label")
    return parser.parse_args(argv)
# Maps each allowed --dtype choice to the numpy dtype used to decode the
# hex-encoded label bytes stored in the accuracy log.
dtype_map = dict(
    float32=np.float32,
    int32=np.int32,
    int64=np.int64,
)
def main():
    """Score a loadgen accuracy log against the ImageNet validation labels.

    Reads options via ``get_args()`` and loads ``(image, label)`` pairs from
    ``--imagenet-val-file``. Each non-blank line holds the image name and an
    integer label as its first two whitespace-separated columns.

    Each entry of ``--mlperf-accuracy-file`` is then compared with the pair at
    its ``qsl_idx``. The entry's ``data`` hex string is decoded with the numpy
    dtype chosen by ``--dtype``, and its first element is taken as the
    predicted label. Repeated ``qsl_idx`` values are scored only once.

    Prints the accuracy over the unique entries. With ``--verbose`` it also
    prints every mismatch and the number of duplicate entries ignored.
    """
    args = get_args()

    imagenet = []
    with open(args.imagenet_val_file, "r", encoding="utf-8") as f:
        for line in f:
            cols = line.split()
            # tolerate blank lines (e.g. an empty trailing line) instead of
            # failing with IndexError on cols[1]
            if not cols:
                continue
            imagenet.append((cols[0], int(cols[1])))

    with open(args.mlperf_accuracy_file, "r", encoding="utf-8") as f:
        results = json.load(f)

    # the label dtype is the same for every entry; resolve it once
    label_dtype = dtype_map[args.dtype]

    seen = set()
    good = 0
    for j in results:
        idx = j["qsl_idx"]
        # de-dupe in case loadgen sends the same image multiple times
        if idx in seen:
            continue
        seen.add(idx)

        # get the expected label and image
        img, label = imagenet[idx]

        # reconstruct label from mlperf accuracy log
        data = np.frombuffer(bytes.fromhex(j["data"]), label_dtype)
        found = int(data[0])
        if label == found:
            good += 1
        elif args.verbose:
            print("{}, expected: {}, found {}".format(img, label, found))

    total = len(seen)
    # an empty accuracy log would otherwise raise ZeroDivisionError
    accuracy = 100. * good / total if total else 0.0
    print("accuracy={:.3f}%, good={}, total={}".format(accuracy, good, total))
    if args.verbose:
        print("found and ignored {} dupes".format(len(results) - total))
# Script entry point: run the accuracy computation only when this file is
# executed directly, not when it is imported as a module.
if __name__ == "__main__":
    main()