diff --git a/ms2deepscore/train_new_model/train_ms2deepscore.py b/ms2deepscore/train_new_model/train_ms2deepscore.py index b9a6ff02..38277d20 100644 --- a/ms2deepscore/train_new_model/train_ms2deepscore.py +++ b/ms2deepscore/train_new_model/train_ms2deepscore.py @@ -23,11 +23,37 @@ def train_ms2ds_model( ): """Full workflow to train a MS2DeepScore model. """ + train_generator, validation_loss_calculator = prepare_folders_and_generators( + training_spectra, + validation_spectra, + results_folder, + settings) + + model = SiameseSpectralModel(settings=settings) + + output_model_file_name = os.path.join(results_folder, settings.model_file_name) + + history = train(model, + train_generator, + num_epochs=settings.epochs, + learning_rate=settings.learning_rate, + validation_loss_calculator=validation_loss_calculator, + patience=settings.patience, + loss_function=settings.loss_function, + checkpoint_filename=output_model_file_name, lambda_l1=0, lambda_l2=0) + return model, history + + +def prepare_folders_and_generators( + training_spectra, + validation_spectra, + results_folder, + settings + ): os.makedirs(results_folder, exist_ok=True) # Save settings settings.save_to_file(os.path.join(results_folder, "settings.json")) - output_model_file_name = os.path.join(results_folder, settings.model_file_name) selected_compound_pairs_training, selected_training_spectra = select_compound_pairs_wrapper( training_spectra, settings=settings) @@ -37,20 +63,9 @@ def train_ms2ds_model( selected_compound_pairs=selected_compound_pairs_training, settings=settings) - model = SiameseSpectralModel(settings=settings) - validation_loss_calculator = ValidationLossCalculator(validation_spectra, settings=settings) - - history = train(model, - train_generator, - num_epochs=settings.epochs, - learning_rate=settings.learning_rate, - validation_loss_calculator=validation_loss_calculator, - patience=settings.patience, - loss_function=settings.loss_function, - checkpoint_filename=output_model_file_name, lambda_l1=0, lambda_l2=0) - return model, history + return train_generator, validation_loss_calculator def plot_history(losses, val_losses, file_name: Optional[str] = None): diff --git a/ms2deepscore/wrapper_functions/training_wrapper_functions.py b/ms2deepscore/wrapper_functions/training_wrapper_functions.py index 9cd3cc96..e1200d05 100644 --- a/ms2deepscore/wrapper_functions/training_wrapper_functions.py +++ b/ms2deepscore/wrapper_functions/training_wrapper_functions.py @@ -7,14 +7,16 @@ from datetime import datetime from matchms.exporting import save_spectra from matchms.importing import load_spectra -from ms2deepscore import MS2DeepScore +from ms2deepscore.models.SiameseSpectralModel import (SiameseSpectralModel, + train) from ms2deepscore.models.loss_functions import bin_dependent_losses from ms2deepscore.benchmarking.calculate_scores_for_validation import \ calculate_true_values_and_predictions_for_validation_spectra from ms2deepscore.SettingsMS2Deepscore import SettingsMS2Deepscore from ms2deepscore.train_new_model.split_positive_and_negative_mode import \ split_by_ionmode -from ms2deepscore.train_new_model.train_ms2deepscore import train_ms2ds_model, plot_history +from ms2deepscore.train_new_model.train_ms2deepscore import \ + train_ms2ds_model, plot_history, prepare_folders_and_generators from ms2deepscore.train_new_model.validation_and_test_split import \ split_spectra_in_random_inchikey_sets from ms2deepscore.utils import load_spectra_as_list @@ -115,6 +117,7 @@ def parameter_search( negative_validation_spectra = stored_training_data.load_negative_train_split("validation") results = {} + train_generator = None # Generate all combinations of setting variations keys, values = zip(*setting_variations.items()) @@ -130,13 +133,49 @@ def parameter_search( results_folder = os.path.join(stored_training_data.trained_models_folder, model_directory_name) print(f"Testing combination: {params}") - + + fields_affecting_generators = [ + "fingerprint_type", + "fingerprint_nbits", + "max_pairs_per_bin", + "same_prob_bins", + "include_diagonal", + + ] + search_includes_generator_parameters = False + for field in fields_affecting_generators: + if field in keys: + search_includes_generator_parameters = True + + if search_includes_generator_parameters and (train_generator is None): + train_generator, validation_loss_calculator = prepare_folders_and_generators( + training_spectra, + validation_spectra, + results_folder, + settings) + + model = SiameseSpectralModel(settings=settings) + + output_model_file_name = os.path.join(results_folder, settings.model_file_name) + # Train model - _, history = train_ms2ds_model( - training_spectra, validation_spectra, - results_folder, - settings - ) + try: + history = train( + model, + train_generator, + num_epochs=settings.epochs, + learning_rate=settings.learning_rate, + validation_loss_calculator=validation_loss_calculator, + patience=settings.patience, + loss_function=settings.loss_function, + checkpoint_filename=output_model_file_name, + lambda_l1=0, lambda_l2=0 + ) + except: + print("---- Model training failed! ----") + print("---- Settings ----") + print(settings.get_dict()) + continue ms2deepsore_model_file_name = os.path.join(stored_training_data.trained_models_folder, model_directory_name, @@ -309,4 +348,4 @@ def load_training_data(self, return self.load_negative_train_split(data_split_type) if ionisation_mode == "both": return self.load_positive_train_split(data_split_type) + self.load_negative_train_split(data_split_type) - raise ValueError("expected ionisation mode to be 'positive', 'negative' or 'both'") \ No newline at end of file + raise ValueError("expected ionisation mode to be 'positive', 'negative' or 'both'")