Skip to content

Commit

Permalink
added preprocess for imagenet, combined name of model and init weights
Browse files Browse the repository at this point in the history
  • Loading branch information
taxe10 committed Feb 20, 2024
1 parent 5c29357 commit e6f6efa
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 9 deletions.
20 changes: 14 additions & 6 deletions src/predict_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def call(self, inputs, training=False):
parser.add_argument('-d', '--data_info', help='path to dataframe of filepaths')
parser.add_argument('-m', '--model_dir', help='input directory')
parser.add_argument('-o', '--output_dir', help='output directory')
parser.add_argument('-p', '--parameters', help='list of training parameters')
parser.add_argument('-p', '--parameters', help='list of prediction parameters')

args = parser.parse_args()
data_parameters = DataAugmentationParams(**json.loads(args.parameters))
Expand All @@ -44,20 +44,28 @@ def call(self, inputs, training=False):

# Load trained model and parameters
loaded_model = load_model(args.model_dir+'/model.keras')
target_size = model_list[loaded_model._name]
model_name, *weights = loaded_model._name.split('_')
target_size = model_list[model_name]
custom_model = CustomModel(loaded_model) # Modify trained model to return prob and f_vec

# Prepare data generators and create a tf.data pipeline of augmented images
test_dataset, datasets_uris, data_type = get_dataset(args.data_info, shuffle=False)
predict_dataset, datasets_uris, data_type = get_dataset(args.data_info, shuffle=False)
predict_generator = predict_dataset.map(lambda x: data_preprocessing(x, (target_size,target_size), data_type, data_parameters.log))

# Preprocess input according to the model if weights are set to imagenet
if weights == ['imagenet']:
preprocess_input_description = f"tf.keras.applications.{model_name}.preprocess_input"
preprocess_input = compile(preprocess_input_description, "<string>", 'eval')
predict_generator = predict_dataset.batch(batch_size).map(lambda x: (preprocess_input(x)))
else:
predict_generator = predict_generator.batch(batch_size)

test_generator = test_dataset.map(lambda x: data_preprocessing(x, (target_size,target_size), data_type, data_parameters.log))
test_generator = test_generator.batch(batch_size)
with open(args.model_dir+'/class_info.json', 'r') as json_file:
classes = json.load(json_file)
class_num = len(classes)

# Start prediction process
prob, f_vec = custom_model.predict(test_generator,
prob, f_vec = custom_model.predict(predict_generator,
verbose=0,
callbacks=[PredictionCustomCallback(datasets_uris, classes)])

Expand Down
14 changes: 11 additions & 3 deletions src/train_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,8 +70,16 @@
train_generator = train_dataset.map(lambda x, y: (data_preprocessing(x, (target_size,target_size), data_type, data_parameters.log), y))
val_generator = val_dataset.map(lambda x, y: (data_preprocessing(x, (target_size,target_size), data_type, data_parameters.log), y))

train_generator = train_generator.batch(batch_size).map(lambda x, y: (data_augmentation(x), y))
val_generator = val_generator.batch(batch_size).map(lambda x, y: (data_augmentation(x), y))
# Preprocess input according to the model if weights are set to imagenet
if weights == 'imagenet':
preprocess_input_description = f"tf.keras.applications.{nn_model}.preprocess_input"
preprocess_input = compile(preprocess_input_description, "<string>", 'eval')
train_generator = train_generator.batch(batch_size).map(lambda x, y: (preprocess_input(data_augmentation(x)), y))
val_generator = val_generator.batch(batch_size).map(lambda x, y: (preprocess_input(data_augmentation(x)), y))
else:
train_generator = train_generator.batch(batch_size).map(lambda x, y: (data_augmentation(x), y))
val_generator = val_generator.batch(batch_size).map(lambda x, y: (data_augmentation(x), y))

class_num = len(classes)

# Define optimizer
Expand All @@ -93,7 +101,7 @@
predictions = layers.Dense(class_num, activation='softmax', name="predictions")(x)
model = tf.keras.models.Model(inputs=base_model.input,
outputs=predictions,
name=base_model._name)
name=f"{base_model._name}_{weights}")
else:
model_description = f"tf.keras.applications.{nn_model}(include_top=True, weights=None, \
input_tensor=None, classes={class_num})"
Expand Down

0 comments on commit e6f6efa

Please sign in to comment.