Skip to content

Commit ad53656

Browse files
committedOct 6, 2018
several bug fixes
1 parent bfe96f2 commit ad53656

File tree

3 files changed

+29
-14
lines changed

3 files changed

+29
-14
lines changed
 

‎conf.json

+4-4
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,14 @@
11
{
2-
"weights" : "imagenet",
2+
"weights" : "model_data/trained_1.h5",
33

4-
"train_path" : "test_data",
4+
"train_path" : "train_data",
55
"test_path" : "test_data",
66
"augmented_data" : "augmented_data",
77
"model_path" : "model_data",
88

99
"validation_split": 0.15,
1010
"batch_size" : 64,
11-
"epochs" : 40,
12-
"epochs_after_unfreeze" : 10,
11+
"epochs" : 50,
12+
"epochs_after_unfreeze" : 5,
1313
"data_augmentation" : false
1414
}

‎train.py

+14-10
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88

99
# keras imports
1010
from keras.applications.mobilenetv2 import MobileNetV2
11-
from keras.models import Model
11+
from keras.models import Model, load_model
1212
from keras.layers import Dense, GlobalAveragePooling2D, Input
1313
from keras.utils import to_categorical
1414
from keras.optimizers import SGD
@@ -19,7 +19,7 @@
1919
import datetime
2020
import time
2121

22-
from utils import generate_batches, generate_batches_with_augmentation
22+
from utils import generate_batches, generate_batches_with_augmentation, create_folders
2323

2424
# load the user configs
2525
with open('conf.json') as f:
@@ -36,15 +36,19 @@
3636
validation_split = config["validation_split"]
3737
data_augmentation = config["data_augmentation"]
3838
epochs_after_unfreeze = config["epochs_after_unfreeze"]
39+
create_folders(model_path, augmented_data)
3940

4041
# create model
41-
base_model = MobileNetV2(include_top=False, weights=weights,
42-
input_tensor=Input(shape=(224,224,3)), input_shape=(224,224,3))
43-
top_layers = base_model.output
44-
top_layers = GlobalAveragePooling2D()(top_layers)
45-
top_layers = Dense(1024, activation='relu')(top_layers)
46-
predictions = Dense(10, activation='softmax')(top_layers)
47-
model = Model(inputs=base_model.input, outputs=predictions)
42+
if weights=="imagenet":
43+
base_model = MobileNetV2(include_top=False, weights=weights,
44+
input_tensor=Input(shape=(224,224,3)), input_shape=(224,224,3))
45+
top_layers = base_model.output
46+
top_layers = GlobalAveragePooling2D()(top_layers)
47+
top_layers = Dense(1024, activation='relu')(top_layers)
48+
predictions = Dense(10, activation='softmax')(top_layers)
49+
model = Model(inputs=base_model.input, outputs=predictions)
50+
else:
51+
model = load_model(weights)
4852
print ("[INFO] successfully loaded base model and model...")
4953

5054
# create callbacks
@@ -54,7 +58,7 @@
5458
start = time.time()
5559

5660
print ("Freezing the base layers. Unfreeze the top 2 layers...")
57-
for layer in base_model.layers:
61+
for layer in model.layers[:-3]:
5862
layer.trainable = False
5963
model.compile(optimizer='rmsprop', loss='categorical_crossentropy')
6064

‎utils.py

+11
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,10 @@
44
import numpy as np
55
from keras.utils import to_categorical
66
from keras.preprocessing.image import ImageDataGenerator
7+
import os
78

89
def generate_batches(path, batchSize):
10+
while True:
911
files = glob.glob(path + '/*/*jpg')
1012
for f in range(0, len(files), batchSize):
1113
x = []
@@ -29,3 +31,12 @@ def generate_batches_with_augmentation(train_path, batch_size, validation_split,
2931
batch_size=batch_size,
3032
save_to_dir=augmented_data)
3133
return train_generator
34+
35+
def create_folders(model_path, augmented_data):
36+
if not os.path.exists(model_path):
37+
os.mkdir(model_path)
38+
if not os.path.exists(augmented_data):
39+
os.mkdir(augmented_data)
40+
if not os.path.exists("logs"):
41+
os.mkdir("logs")
42+

0 commit comments

Comments
 (0)
Please sign in to comment.