1
1
# filter warnings
import warnings
warnings.simplefilter(action="ignore", category=FutureWarning)

import os
# TF_CPP_MIN_LOG_LEVEL must be set *before* TensorFlow is imported: the
# native (C++) side of TF reads it once at import time, so setting it after
# the import (as before) leaves the C++ log spam enabled.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
import tensorflow as tf
# Silence the TF1.x Python-side logger as well.
tf.logging.set_verbosity(tf.logging.ERROR)
4
8
5
9
# keras imports
6
10
from keras .applications .mobilenetv2 import MobileNetV2
9
13
from keras .layers import Dense , GlobalAveragePooling2D , Input
10
14
from keras .utils import to_categorical
11
15
from keras .optimizers import SGD
16
+ from keras .callbacks import ModelCheckpoint
12
17
13
18
# other imports
14
19
import json
15
20
import datetime
16
21
import time
17
22
23
+ from utils import generate_batches
24
+
18
25
# load the user configs from conf.json in the working directory
with open('conf.json') as config_file:
    config = json.load(config_file)

# unpack the settings this script uses
test_path = config["test_path"]
model_path = config["model_path"]
batch_size = config["batch_size"]
epochs = config["epochs"]
augmented_data = config["augmented_data"]
validation_split = config["validation_split"]
epochs_after_unfreeze = config["epochs_after_unfreeze"]
31
39
32
40
# create model
33
41
base_model = MobileNetV2 (include_top = False , weights = weights ,
39
47
model = Model (inputs = base_model .input , outputs = predictions )
40
48
print ("[INFO] successfully loaded base model and model..." )
41
49
50
# create callbacks: checkpoint weights every 5 epochs, keeping only the
# best (lowest) training loss seen so far
checkpoint = ModelCheckpoint("logs/weights.h5",
                             monitor='loss',
                             save_best_only=True,
                             period=5)

# start time
start = time.time()

# phase 1: train only the new head, frozen base
model.compile(optimizer='rmsprop', loss='categorical_crossentropy')
49
60
50
61
print("Start training...")

# Count the training images to derive steps_per_epoch (assumes a
# train_path/<class>/<image>jpg layout -- TODO confirm this matches what
# utils.generate_batches expects).
import glob
files = glob.glob(os.path.join(train_path, '*', '*jpg'))
samples = len(files)
# Guard against steps_per_epoch == 0 (Keras rejects it) when there are
# fewer images than batch_size.
steps = max(1, samples // batch_size)
model.fit_generator(generate_batches(train_path, batch_size), epochs=epochs,
                    steps_per_epoch=steps, verbose=1, callbacks=[checkpoint])

print("Saving...")
model.save(model_path + "/save_model_stage1.h5")
67
70
68
- """
69
- print ("Visualization...")
70
- for i, layer in enumerate(base_model.layers):
71
- print(i, layer.name)
71
+ # print ("Visualization...")
72
+ # for i, layer in enumerate(base_model.layers):
73
+ # print(i, layer.name)
72
74
73
75
print("Unfreezing all layers...")
for layer in model.layers:
    layer.trainable = True
# Recompile so the new trainable flags take effect; fine-tune everything
# with a small SGD learning rate.
model.compile(optimizer=SGD(lr=0.0001, momentum=0.9), loss='categorical_crossentropy')

print("Start training - phase 2...")
# NOTE(review): this checkpoint writes to the same file as the phase-1
# checkpoint -- confirm that overwriting the stage-1 weights is intended.
checkpoint = ModelCheckpoint("logs/weights.h5", monitor='loss', save_best_only=True, period=1)
model.fit_generator(generate_batches(train_path, batch_size), epochs=epochs_after_unfreeze,
                    steps_per_epoch=samples // batch_size, verbose=1, callbacks=[checkpoint])

print("Saving...")
model.save(model_path + "/save_model_stage2.h5")
83
- """
87
+
84
88
# end time: report wall-clock finish and total elapsed seconds
end = time.time()
elapsed = end - start
print("[STATUS] end time - {}".format(datetime.datetime.now().strftime("%Y-%m-%d %H:%M")))
print("[STATUS] total duration: {}".format(elapsed))
0 commit comments