Commit a1777c1
change batches generator method
1 parent 069f75f

File tree: 4 files changed (+49, -25 lines)

.gitignore (+2, -1)

@@ -107,4 +107,5 @@ venv.bak/
 train_data/
 test_data/
 model_data/
-augmented_data/
+augmented_data/
+logs/

conf.json (+5, -3)

@@ -1,11 +1,13 @@
 {
     "weights" : "imagenet",
 
-    "train_path" : "train_data",
+    "train_path" : "test_data",
     "test_path" : "test_data",
     "augmented_data" : "augmented_data",
     "model_path" : "model_data",
 
-    "validation_split": 0.10,
-    "batch_size" : 64
+    "validation_split": 0.15,
+    "batch_size" : 64,
+    "epochs" : 1,
+    "epochs_after_unfreeze" : 1
 }
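The hunk above covers the whole file, so after this commit conf.json reads as below (only the indentation is assumed; everything else is taken from the diff). Note that train_path now points at test_data rather than train_data.

{
    "weights" : "imagenet",

    "train_path" : "test_data",
    "test_path" : "test_data",
    "augmented_data" : "augmented_data",
    "model_path" : "model_data",

    "validation_split": 0.15,
    "batch_size" : 64,
    "epochs" : 1,
    "epochs_after_unfreeze" : 1
}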

train.py (+25, -21)

@@ -1,6 +1,10 @@
 # filter warnings
 import warnings
 warnings.simplefilter(action="ignore", category=FutureWarning)
+import os
+import tensorflow as tf
+os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
+tf.logging.set_verbosity(tf.logging.ERROR)
 
 # keras imports
 from keras.applications.mobilenetv2 import MobileNetV2
@@ -9,12 +13,15 @@
 from keras.layers import Dense, GlobalAveragePooling2D, Input
 from keras.utils import to_categorical
 from keras.optimizers import SGD
+from keras.callbacks import ModelCheckpoint
 
 # other imports
 import json
 import datetime
 import time
 
+from utils import generate_batches
+
 # load the user configs
 with open('conf.json') as f:
     config = json.load(f)
@@ -25,9 +32,10 @@
 test_path = config["test_path"]
 model_path = config["model_path"]
 batch_size = config["batch_size"]
+epochs = config["epochs"]
 augmented_data = config["augmented_data"]
 validation_split = config["validation_split"]
-
+epochs_after_unfreeze = config["epochs_after_unfreeze"]
 
 # create model
 base_model = MobileNetV2(include_top=False, weights=weights,
@@ -39,6 +47,9 @@
 model = Model(inputs=base_model.input, outputs=predictions)
 print ("[INFO] successfully loaded base model and model...")
 
+# create callbacks
+checkpoint = ModelCheckpoint("logs/weights.h5", monitor='loss', save_best_only=True, period=5)
+
 # start time
 start = time.time()
 
@@ -48,40 +59,33 @@
 model.compile(optimizer='rmsprop', loss='categorical_crossentropy')
 
 print ("Start training...")
-train_datagen = ImageDataGenerator(
-    rescale=1./255,
-    shear_range=0.2,
-    rotation_range=0.3,
-    zoom_range=0.1,
-    horizontal_flip=False,
-    validation_split=validation_split)
-train_generator = train_datagen.flow_from_directory(
-    train_path,
-    target_size=(224, 224),
-    batch_size=batch_size,
-    save_to_dir=augmented_data)
-model.fit_generator(train_generator, verbose=1, epochs=30)
+import glob
+files = glob.glob(train_path + '/*/*jpg')
+samples = len(files)
+model.fit_generator(generate_batches(train_path, batch_size), epochs=epochs,
+    steps_per_epoch=samples//batch_size, verbose=1, callbacks=[checkpoint])
 
 print ("Saving...")
 model.save(model_path + "/save_model_stage1.h5")
 
-"""
-print ("Visualization...")
-for i, layer in enumerate(base_model.layers):
-    print(i, layer.name)
+# print ("Visualization...")
+# for i, layer in enumerate(base_model.layers):
+#     print(i, layer.name)
 
 print ("Unfreezing all layers...")
 for i in range(len(model.layers)):
     model.layers[i].trainable = True
 model.compile(optimizer=SGD(lr=0.0001, momentum=0.9), loss='categorical_crossentropy')
 
 print ("Start training - phase 2...")
-model.fit_generator(train_generator, verbose=1, epochs=1)
+checkpoint = ModelCheckpoint("logs/weights.h5", monitor='loss', save_best_only=True, period=1)
+model.fit_generator(generate_batches(train_path, batch_size), epochs=epochs_after_unfreeze,
+    steps_per_epoch=samples//batch_size, verbose=1, callbacks=[checkpoint])
 
 print ("Saving...")
 model.save(model_path + "/save_model_stage2.h5")
-"""
+
 # end time
 end = time.time()
 print ("[STATUS] end time - {}".format(datetime.datetime.now().strftime("%Y-%m-%d %H:%M")))
-print ("[STATUS] total duration: {}".format(end - start))
+print ("[STATUS] total duration: {}".format(end - start))

utils.py (+17)

@@ -0,0 +1,17 @@
+
+import glob
+import cv2
+import numpy as np
+from keras.utils import to_categorical
+
+def generate_batches(path, batchSize):
+    files = glob.glob(path + '/*/*jpg')
+    for f in range(0, len(files), batchSize):
+        x = []
+        y = []
+        for i in range(f, f+batchSize):
+            if i < len(files):
+                img = cv2.imread(files[i])
+                x.append(cv2.resize(img, (224, 224)))
+                y.append(int(files[i].split('/')[1]))
+        yield (np.array(x), to_categorical(y, num_classes=10))
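One caveat on the new generator, noted as an observation rather than a change: generate_batches makes a single pass over the files and then stops, which matches the epochs = 1 and epochs_after_unfreeze = 1 settings, but fit_generator expects to keep drawing batches when more epochs are configured. A minimal wrapper along these lines could restart it (the name infinite_batches is hypothetical, not in the repo). Note also that the label lookup files[i].split('/')[1] assumes path is a single top-level directory whose class subfolders are named with integers below 10, since to_categorical is called with num_classes=10.

# sketch only, not part of the commit: restart the one-pass generator so
# fit_generator can request as many epochs as conf.json asks for
def infinite_batches(path, batch_size):
    while True:
        for batch in generate_batches(path, batch_size):
            yield batch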
