-
Notifications
You must be signed in to change notification settings - Fork 17
/
evaluation.py
66 lines (45 loc) · 1.41 KB
/
evaluation.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
import numpy as np
from vgg16 import VGGNet
from vgg5 import VGGNet5
from resnet import ResNet50 as ResNet
from utils import load_mnist
def load_models(model_dir="model"):
    """Instantiate the ensemble members from their saved weight files.

    Parameters
    ----------
    model_dir : str, optional
        Directory containing ``vggnet.h5``, ``resnet.h5`` and ``vggnet5.h5``.
        Defaults to ``"model"`` (resolved against the current working
        directory), which matches the previously hard-coded paths.

    Returns
    -------
    list
        ``[VGGNet, ResNet, VGGNet5]`` instances, in that order. ``ResNet`` is
        the ``ResNet50`` class imported from ``resnet``.
    """
    import os

    # Order matters to callers that report per-model results by index.
    return [
        VGGNet(os.path.join(model_dir, "vggnet.h5")),
        ResNet(os.path.join(model_dir, "resnet.h5")),
        VGGNet5(os.path.join(model_dir, "vggnet5.h5")),
    ]
def evaluate(prediction, true_labels):
    """Return the classification accuracy of ``prediction`` against ``true_labels``.

    Parameters
    ----------
    prediction : array-like, shape (n_samples, n_classes) or (n_samples,)
        Per-class scores / one-hot predictions (decoded with argmax over the
        class axis), or already-decoded integer class indices.
    true_labels : array-like, shape (n_samples, n_classes) or (n_samples,)
        One-hot encoded ground truth, or integer class indices.

    Returns
    -------
    accuracy : float
        Fraction of samples whose predicted class equals the true class.

    Raises
    ------
    ValueError
        If the inputs contain no samples, or a different number of samples.
        The sample-count check prevents NumPy broadcasting from silently
        comparing a single prediction against every label.
    """

    def _as_class_indices(values):
        # 1-D input is taken to be class indices already. Anything else is
        # decoded along axis 1, exactly as before.
        arr = np.asarray(values)
        return arr if arr.ndim == 1 else np.argmax(arr, axis=1)

    pred = _as_class_indices(prediction)
    true = _as_class_indices(true_labels)

    if pred.shape != true.shape:
        raise ValueError(
            "prediction and true_labels have mismatched sample shapes: "
            "{} vs {}".format(pred.shape, true.shape)
        )
    if pred.size == 0:
        raise ValueError("cannot compute accuracy of zero samples")

    return float(np.mean(pred == true))
def main():
    """Score every loaded model, then the averaged ensemble, on the test split.

    The test split is taken as the third ``(inputs, labels)`` pair returned
    by ``load_mnist``. The function prints one accuracy line per model, in
    ``load_models`` order. It then prints the accuracy of the element-wise
    mean of all models' predictions.
    """
    models = load_models()
    _, _, (x_test, y_test) = load_mnist()

    # Gather each model's raw predictions and report its standalone accuracy.
    predictions = []
    for model_idx, model in enumerate(models):
        model_pred = model.predict(x_test)
        predictions.append(model_pred)
        print("Model-{}: {:>.5%}".format(model_idx, evaluate(model_pred, y_test)))

    # Ensemble by stacking along a new leading axis and averaging over models.
    ensemble_pred = np.asarray(predictions).mean(axis=0)
    print("Final Test Accuracy: {:>.5%}".format(evaluate(ensemble_pred, y_test)))
if __name__ == "__main__":
    # Run the evaluation only when executed as a script, not on import.
    main()