diff --git a/src/predict_model.py b/src/predict_model.py index 136f799..ec329c4 100644 --- a/src/predict_model.py +++ b/src/predict_model.py @@ -35,7 +35,7 @@ def call(self, inputs, training=False): parser.add_argument('-d', '--data_info', help='path to dataframe of filepaths') parser.add_argument('-m', '--model_dir', help='input directory') parser.add_argument('-o', '--output_dir', help='output directory') - parser.add_argument('-p', '--parameters', help='list of training parameters') + parser.add_argument('-p', '--parameters', help='list of prediction parameters') args = parser.parse_args() data_parameters = DataAugmentationParams(**json.loads(args.parameters)) @@ -44,20 +44,28 @@ def call(self, inputs, training=False): # Load trained model and parameters loaded_model = load_model(args.model_dir+'/model.keras') - target_size = model_list[loaded_model._name] + model_name, *weights = loaded_model._name.split('_') + target_size = model_list[model_name] custom_model = CustomModel(loaded_model) # Modify trained model to return prob and f_vec # Prepare data generators and create a tf.data pipeline of augmented images - test_dataset, datasets_uris, data_type = get_dataset(args.data_info, shuffle=False) + predict_dataset, datasets_uris, data_type = get_dataset(args.data_info, shuffle=False) + predict_generator = predict_dataset.map(lambda x: data_preprocessing(x, (target_size,target_size), data_type, data_parameters.log)) + + # Preprocess input according to the model if weights are set to imagenet + if weights == ['imagenet']: + preprocess_input_description = f"tf.keras.applications.{model_name}.preprocess_input" + preprocess_input = compile(preprocess_input_description, "", 'eval') + predict_generator = predict_dataset.batch(batch_size).map(lambda x: (preprocess_input(x))) + else: + predict_generator = predict_generator.batch(batch_size) - test_generator = test_dataset.map(lambda x: data_preprocessing(x, (target_size,target_size), data_type, data_parameters.log)) - test_generator = test_generator.batch(batch_size) with open(args.model_dir+'/class_info.json', 'r') as json_file: classes = json.load(json_file) class_num = len(classes) # Start prediction process - prob, f_vec = custom_model.predict(test_generator, + prob, f_vec = custom_model.predict(predict_generator, verbose=0, callbacks=[PredictionCustomCallback(datasets_uris, classes)]) diff --git a/src/train_model.py b/src/train_model.py index bccf30e..f53f466 100644 --- a/src/train_model.py +++ b/src/train_model.py @@ -70,8 +70,16 @@ train_generator = train_dataset.map(lambda x, y: (data_preprocessing(x, (target_size,target_size), data_type, data_parameters.log), y)) val_generator = val_dataset.map(lambda x, y: (data_preprocessing(x, (target_size,target_size), data_type, data_parameters.log), y)) - train_generator = train_generator.batch(batch_size).map(lambda x, y: (data_augmentation(x), y)) - val_generator = val_generator.batch(batch_size).map(lambda x, y: (data_augmentation(x), y)) + # Preprocess input according to the model if weights are set to imagenet + if weights == 'imagenet': + preprocess_input_description = f"tf.keras.applications.{nn_model}.preprocess_input" + preprocess_input = compile(preprocess_input_description, "", 'eval') + train_generator = train_generator.batch(batch_size).map(lambda x, y: (preprocess_input(data_augmentation(x)), y)) + val_generator = val_generator.batch(batch_size).map(lambda x, y: (preprocess_input(data_augmentation(x)), y)) + else: + train_generator = train_generator.batch(batch_size).map(lambda x, y: (data_augmentation(x), y)) + val_generator = val_generator.batch(batch_size).map(lambda x, y: (data_augmentation(x), y)) + class_num = len(classes) # Define optimizer @@ -93,7 +101,7 @@ predictions = layers.Dense(class_num, activation='softmax', name="predictions")(x) model = tf.keras.models.Model(inputs=base_model.input, outputs=predictions, - name=base_model._name) + name=f"{base_model._name}_{weights}") else: model_description = f"tf.keras.applications.{nn_model}(include_top=True, weights=None, \ input_tensor=None, classes={class_num})"