-
Notifications
You must be signed in to change notification settings - Fork 89
Expand file tree
/
Copy pathauc.py
More file actions
66 lines (51 loc) · 2.18 KB
/
auc.py
File metadata and controls
66 lines (51 loc) · 2.18 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
from tensorflow import keras
from datetime import datetime
import numpy as np
from sklearn.metrics import log_loss, roc_auc_score
class roc_auc_callback(keras.callbacks.Callback):
    """Keras callback that reports validation ROC AUC after every epoch.

    At the end of each epoch the model is run over every batch of
    ``validation_generator``. The labels and predictions are concatenated,
    scored with ``sklearn.metrics.roc_auc_score``, printed, and written to
    ``logs['roc_auc_val']`` (the ``logs`` dict Keras passes to the hook).

    Args:
        validation_generator: indexable batch source supporting ``len()`` and
            ``generator[i] -> (X, y)`` (e.g. a ``keras.utils.Sequence``).
    """

    def __init__(self, validation_generator):
        # The base initialiser sets up the state Keras expects on every
        # callback; the original implementation skipped it.
        super().__init__()
        self.validation_generator = validation_generator

    # on_train_begin / on_train_end / on_epoch_begin are inherited no-ops from
    # keras.callbacks.Callback, so they are not overridden here.

    def _collect_labels_and_scores(self, passes=1):
        """Predict over all validation batches, ``passes`` times in a row.

        Each batch is fetched exactly once per pass and the model predicts on
        that same batch. Labels and scores therefore stay aligned even if the
        generator is non-deterministic, and no batch is loaded twice.

        Returns:
            ``(y_true, y_score)`` as concatenated numpy arrays, or
            ``(None, None)`` when the generator yields no batches.
        """
        labels, scores = [], []
        for _ in range(passes):
            for batch_idx in range(len(self.validation_generator)):
                X, y = self.validation_generator[batch_idx]
                labels.append(y)
                # np.asarray guards against backends returning tensors.
                scores.append(np.asarray(self.model.predict_on_batch(X)))
        if not labels:
            return None, None
        return np.concatenate(labels), np.concatenate(scores)

    def _compute_and_log_auc(self, logs, passes=1):
        """Score the validation data, print the AUC and store it in ``logs``.

        Returns:
            The ROC AUC, or ``None`` if the generator produced no batches (in
            which case ``logs`` is left untouched).

        Raises:
            ValueError: propagated from ``roc_auc_score`` for inputs it
                rejects (e.g. only one class present in ``y_true``).
        """
        y_true, y_score = self._collect_labels_and_scores(passes)
        if y_true is None:
            print("{} no validation batches, skipping auc".format(datetime.now()))
            return None
        roc_val = roc_auc_score(y_true, y_score)
        print("{} roc={}".format(datetime.now(), roc_val))
        # `logs` defaults to None (not a shared mutable {}); only write when
        # Keras actually handed us a dict.
        if logs is not None:
            logs['roc_auc_val'] = roc_val
        return roc_val

    def on_epoch_end(self, epoch, logs=None):
        """Keras hook: compute validation ROC AUC once per epoch.

        Args:
            epoch: epoch index supplied by Keras (printed as ``epoch + 1``).
            logs: metrics dict from Keras; ``'roc_auc_val'`` is added to it.
        """
        print("epoch {}, {} start auc calcing...".format(epoch + 1, datetime.now()))
        self._compute_and_log_auc(logs)
        print("{} end auc calcing...".format(datetime.now()))

    def on_epoch_end2(self, epoch, logs=None):
        """Alternative AUC computation over ``validation_generator.splits`` passes.

        Keras does not call this automatically (it is not a Callback hook
        name); it is kept for callers that invoke it explicitly.

        Args:
            epoch: unused; kept for signature compatibility with on_epoch_end.
            logs: metrics dict; ``'roc_auc_val'`` is added to it.
        """
        print("{} start auc calcing...".format(datetime.now()))
        # NOTE(review): every pass reads the same batch indices, so passes only
        # see different data if the generator's __getitem__ depends on
        # internal state that changes between passes -- confirm against the
        # generator implementation.
        self._compute_and_log_auc(logs, passes=self.validation_generator.splits)
if __name__ == '__main__':
    # Quick sanity check of roc_auc_score on a hand-made example:
    # three positives (0.8, 0.7, 0.55) and one negative (0.65), so 2 of the
    # 3 positive/negative pairs are ranked correctly and the AUC is 2/3.
    labels = [1, 1, 0, 1]
    scores = [0.8, 0.7, 0.65, 0.55]
    demo_auc = roc_auc_score(y_true=labels, y_score=scores)
    print(demo_auc)