diff --git a/src/train_model.py b/src/train_model.py index a33c655..d79a3b9 100644 --- a/src/train_model.py +++ b/src/train_model.py @@ -113,15 +113,19 @@ model.summary() # fit model while also keeping track of data for dash plots - model.fit(train_generator, - validation_data=val_generator, - epochs=epochs, - verbose=0, - callbacks=[TrainCustomCallback()], - shuffle=data_parameters.shuffle) + history = model.fit( + train_generator, + validation_data=val_generator, + epochs=epochs, + verbose=0, + callbacks=[TrainCustomCallback()], + shuffle=data_parameters.shuffle + ) # save model model.save(args.output_dir+'/model.keras') with open(args.output_dir+'/class_info.json', 'w') as json_file: json.dump(classes, json_file) + with open(args.output_dir+'/history.json', 'w') as json_file: + json.dump(history.history, json_file) logging.info("Training process completed") \ No newline at end of file