|
8 | 8 |
|
9 | 9 | # keras imports
|
10 | 10 | from keras.applications.mobilenetv2 import MobileNetV2
|
11 |
| -from keras.models import Model |
| 11 | +from keras.models import Model, load_model |
12 | 12 | from keras.layers import Dense, GlobalAveragePooling2D, Input
|
13 | 13 | from keras.utils import to_categorical
|
14 | 14 | from keras.optimizers import SGD
|
|
19 | 19 | import datetime
|
20 | 20 | import time
|
21 | 21 |
|
22 |
| -from utils import generate_batches, generate_batches_with_augmentation |
| 22 | +from utils import generate_batches, generate_batches_with_augmentation, create_folders |
23 | 23 |
|
24 | 24 | # load the user configs
|
25 | 25 | with open('conf.json') as f:
|
|
36 | 36 | validation_split = config["validation_split"]
|
37 | 37 | data_augmentation = config["data_augmentation"]
|
38 | 38 | epochs_after_unfreeze = config["epochs_after_unfreeze"]
|
| 39 | +create_folders(model_path, augmented_data) |
39 | 40 |
|
40 | 41 | # create model
|
41 |
| -base_model = MobileNetV2(include_top=False, weights=weights, |
42 |
| - input_tensor=Input(shape=(224,224,3)), input_shape=(224,224,3)) |
43 |
| -top_layers = base_model.output |
44 |
| -top_layers = GlobalAveragePooling2D()(top_layers) |
45 |
| -top_layers = Dense(1024, activation='relu')(top_layers) |
46 |
| -predictions = Dense(10, activation='softmax')(top_layers) |
47 |
| -model = Model(inputs=base_model.input, outputs=predictions) |
| 42 | +if weights=="imagenet": |
| 43 | + base_model = MobileNetV2(include_top=False, weights=weights, |
| 44 | + input_tensor=Input(shape=(224,224,3)), input_shape=(224,224,3)) |
| 45 | + top_layers = base_model.output |
| 46 | + top_layers = GlobalAveragePooling2D()(top_layers) |
| 47 | + top_layers = Dense(1024, activation='relu')(top_layers) |
| 48 | + predictions = Dense(10, activation='softmax')(top_layers) |
| 49 | + model = Model(inputs=base_model.input, outputs=predictions) |
| 50 | +else: |
| 51 | + model = load_model(weights) |
48 | 52 | print ("[INFO] successfully loaded base model and model...")
|
49 | 53 |
|
50 | 54 | # create callbacks
|
|
54 | 58 | start = time.time()
|
55 | 59 |
|
56 | 60 | print ("Freezing the base layers. Unfreeze the top 2 layers...")
|
57 |
| -for layer in base_model.layers: |
| 61 | +for layer in model.layers[:-3]: |
58 | 62 | layer.trainable = False
|
59 | 63 | model.compile(optimizer='rmsprop', loss='categorical_crossentropy')
|
60 | 64 |
|
|
0 commit comments