-
Notifications
You must be signed in to change notification settings - Fork 0
/
train.py
57 lines (44 loc) · 1.75 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
from src.loader import Loader
from utils.generics_functions import plot_activation
import dotenv
from tensorflow.keras.applications.vgg16 import VGG16
from tensorflow.keras.layers import GlobalAveragePooling2D, Dense
from tensorflow.keras.models import Model
from tensorflow.keras.applications.vgg16 import VGG16
from tensorflow.keras import optimizers
import tensorflow as tf
from tensorflow.keras.callbacks import ModelCheckpoint
# Fine-tune an ImageNet-pretrained VGG16 convolutional base with a single
# sigmoid output unit (binary classifier), checkpoint the best epoch by
# validation accuracy, then evaluate, save and inspect one test sample.

IMAGE_SIZE = (224, 224)
VGG16_NOTOP_WEIGHTS = './vgg16/vgg16_weights_tf_dim_ordering_tf_kernels_notop.h5'
# Only the last this-many layers of the VGG16 base stay trainable.
TRAINABLE_TAIL_LAYERS = 8
EPOCHS = 25
BATCH_SIZE = 128
# Test sample whose activations are plotted at the end of the run.
SAMPLE_INDEX = 6

# Pull settings from a local .env file; presumably consumed by Loader
# (e.g. dataset location) -- confirm in src.loader.
dotenv.load_dotenv()

# --- Data ------------------------------------------------------------------
loader = Loader(IMAGE_SIZE)
dataset = loader.create_dataset()
dataset.summary()  # before balancing
# Balance dataset for better training.
# NOTE(review): meaning of (0.8, 0.10) is defined in src.loader; it looks like
# split fractions -- confirm there.
dataset.balance_data(0.8, 0.10)
dataset.summary()  # after balancing

# --- Model -----------------------------------------------------------------
vgg_conv = VGG16(
    weights=VGG16_NOTOP_WEIGHTS,
    include_top=False,
    input_shape=(*IMAGE_SIZE, 3),
)

# Freeze everything except the trailing TRAINABLE_TAIL_LAYERS layers so only
# the deepest part of the pretrained base is fine-tuned.
frozen_layers = vgg_conv.layers[:-TRAINABLE_TAIL_LAYERS]
for frozen in frozen_layers:
    frozen.trainable = False

# Classification head: global average pooling -> one sigmoid unit.
pooled = GlobalAveragePooling2D()(vgg_conv.output)
probability = Dense(1, activation="sigmoid")(pooled)
model = Model(vgg_conv.input, probability)

sgd = optimizers.legacy.SGD(learning_rate=0.005, momentum=0.9)
model.compile(
    loss="binary_crossentropy",
    optimizer=sgd,
    metrics=["accuracy"],
)

# --- Training --------------------------------------------------------------
# Keep only checkpoints that improve val_accuracy; the filename embeds the
# epoch number and the validation accuracy reached.
checkpoint_path = "model_checkpoints/checkpoint-{epoch:02d}-{val_accuracy:.2f}.h5"
checkpoint_callback = ModelCheckpoint(
    filepath=checkpoint_path,
    save_weights_only=False,
    monitor='val_accuracy',
    mode='max',
    save_best_only=True,
    verbose=1,
)

model.fit(
    dataset.train_images,
    dataset.train_labels,
    epochs=EPOCHS,
    batch_size=BATCH_SIZE,
    shuffle=True,
    validation_data=(dataset.val_images, dataset.val_labels),
    callbacks=[checkpoint_callback],
)

# --- Evaluation & export ---------------------------------------------------
model.evaluate(dataset.test_images, dataset.test_labels, verbose=2)
# NOTE(review): this saves the final-epoch model, not the best checkpoint
# written by checkpoint_callback -- confirm that is intended.
model.save('model_miniforge')

# Visual sanity check on a single test image, followed by its label.
sample_image = dataset.test_images[SAMPLE_INDEX]
plot_activation(model, sample_image)
print(dataset.test_labels[SAMPLE_INDEX])