Commit

split functions and avoid double computations in grid search
florian-huber committed Jul 26, 2024
1 parent 65388e2 commit 43bca37
Showing 2 changed files with 76 additions and 22 deletions.
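The hunks below split the setup work (results-folder creation, compound-pair selection, generator construction) out of `train_ms2ds_model` into a new `prepare_folders_and_generators` helper, so that the grid search in `parameter_search` can build the training generator once and reuse it across parameter combinations instead of recomputing it for every run. A minimal sketch of the resulting call pattern, using the names and signatures shown in the diff; the spectra lists, settings values, and results folder are assumptions:

import os
from ms2deepscore.SettingsMS2Deepscore import SettingsMS2Deepscore
from ms2deepscore.models.SiameseSpectralModel import SiameseSpectralModel, train
from ms2deepscore.train_new_model.train_ms2deepscore import prepare_folders_and_generators

# training_spectra / validation_spectra: lists of matchms Spectrum objects,
# assumed to be loaded beforehand (see the usage sketch after the first file).
settings = SettingsMS2Deepscore()  # assumed defaults; adjust to your data
results_folder = "results"

# Folder setup, settings export, compound-pair selection and generator
# creation happen once ...
train_generator, validation_loss_calculator = prepare_folders_and_generators(
    training_spectra, validation_spectra, results_folder, settings)

# ... and the (expensive) generator can then be reused for several trainings,
# e.g. inside a grid search over model hyperparameters.
model = SiameseSpectralModel(settings=settings)
history = train(model, train_generator,
                num_epochs=settings.epochs,
                learning_rate=settings.learning_rate,
                validation_loss_calculator=validation_loss_calculator,
                patience=settings.patience,
                loss_function=settings.loss_function,
                checkpoint_filename=os.path.join(results_folder, settings.model_file_name),
                lambda_l1=0, lambda_l2=0)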
41 changes: 28 additions & 13 deletions ms2deepscore/train_new_model/train_ms2deepscore.py
@@ -23,11 +23,37 @@ def train_ms2ds_model(
):
"""Full workflow to train a MS2DeepScore model.
"""
train_generator, validation_loss_calculator = prepare_folders_and_generators(
training_spectra,
validation_spectra,
results_folder,
settings)

model = SiameseSpectralModel(settings=settings)

output_model_file_name = os.path.join(results_folder, settings.model_file_name)

history = train(model,
train_generator,
num_epochs=settings.epochs,
learning_rate=settings.learning_rate,
validation_loss_calculator=validation_loss_calculator,
patience=settings.patience,
loss_function=settings.loss_function,
checkpoint_filename=output_model_file_name, lambda_l1=0, lambda_l2=0)
return model, history


def prepare_folders_and_generators(
training_spectra,
validation_spectra,
results_folder,
settings
):
os.makedirs(results_folder, exist_ok=True)
# Save settings
settings.save_to_file(os.path.join(results_folder, "settings.json"))

output_model_file_name = os.path.join(results_folder, settings.model_file_name)

selected_compound_pairs_training, selected_training_spectra = select_compound_pairs_wrapper(
training_spectra, settings=settings)
@@ -37,20 +63,9 @@ def train_ms2ds_model(
selected_compound_pairs=selected_compound_pairs_training,
settings=settings)

model = SiameseSpectralModel(settings=settings)

validation_loss_calculator = ValidationLossCalculator(validation_spectra,
settings=settings)

history = train(model,
train_generator,
num_epochs=settings.epochs,
learning_rate=settings.learning_rate,
validation_loss_calculator=validation_loss_calculator,
patience=settings.patience,
loss_function=settings.loss_function,
checkpoint_filename=output_model_file_name, lambda_l1=0, lambda_l2=0)
return model, history
return train_generator, validation_loss_calculator


def plot_history(losses, val_losses, file_name: Optional[str] = None):
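After the split, the full-workflow entry point `train_ms2ds_model` keeps its signature, so existing callers are unaffected. A hedged usage sketch; the file names are placeholders and `load_spectra` is the matchms loader imported in the wrapper module below:

import os
from matchms.importing import load_spectra
from ms2deepscore.SettingsMS2Deepscore import SettingsMS2Deepscore
from ms2deepscore.train_new_model.train_ms2deepscore import train_ms2ds_model

# Hypothetical input files; any spectrum format supported by matchms works.
training_spectra = list(load_spectra("training_spectra.mgf"))
validation_spectra = list(load_spectra("validation_spectra.mgf"))

settings = SettingsMS2Deepscore()  # assumed defaults; adjust to your data
model, history = train_ms2ds_model(training_spectra, validation_spectra,
                                   os.path.join("results", "my_model"), settings)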
57 changes: 48 additions & 9 deletions ms2deepscore/wrapper_functions/training_wrapper_functions.py
Expand Up @@ -7,14 +7,16 @@
from datetime import datetime
from matchms.exporting import save_spectra
from matchms.importing import load_spectra
from ms2deepscore import MS2DeepScore
from ms2deepscore.models.SiameseSpectralModel import (SiameseSpectralModel,
train)
from ms2deepscore.models.loss_functions import bin_dependent_losses
from ms2deepscore.benchmarking.calculate_scores_for_validation import \
calculate_true_values_and_predictions_for_validation_spectra
from ms2deepscore.SettingsMS2Deepscore import SettingsMS2Deepscore
from ms2deepscore.train_new_model.split_positive_and_negative_mode import \
split_by_ionmode
from ms2deepscore.train_new_model.train_ms2deepscore import train_ms2ds_model, plot_history
from ms2deepscore.train_new_model.train_ms2deepscore import \
train_ms2ds_model, plot_history, prepare_folders_and_generators
from ms2deepscore.train_new_model.validation_and_test_split import \
split_spectra_in_random_inchikey_sets
from ms2deepscore.utils import load_spectra_as_list
@@ -115,6 +117,7 @@ def parameter_search(
negative_validation_spectra = stored_training_data.load_negative_train_split("validation")

results = {}
train_generator = None

# Generate all combinations of setting variations
keys, values = zip(*setting_variations.items())
@@ -130,13 +133,49 @@
results_folder = os.path.join(stored_training_data.trained_models_folder, model_directory_name)

print(f"Testing combination: {params}")


fields_affecting_generators = [
"fingerprint_type",
"fingerprint_nbits",
"max_pairs_per_bin",
"same_prob_bins",
"include_diagonal",

]
search_includes_generator_parameters = False
for field in fields_affecting_generators:
if field in keys:
search_includes_generator_parameters = True

if search_includes_generator_parameters and (train_generator is None):
train_generator, validation_loss_calculator = prepare_folders_and_generators(
training_spectra,
validation_spectra,
results_folder,
settings)

model = SiameseSpectralModel(settings=settings)

output_model_file_name = os.path.join(results_folder, settings.model_file_name)

# Train model
_, history = train_ms2ds_model(
training_spectra, validation_spectra,
results_folder,
settings
)
try:
history = train(
model,
train_generator,
num_epochs=settings.epochs,
learning_rate=settings.learning_rate,
validation_loss_calculator=validation_loss_calculator,
patience=settings.patience,
loss_function=settings.loss_function,
checkpoint_filename=output_model_file_name,
lambda_l1=0, lambda_l2=0
)
except:
print("---- Model training failed! ----")
print("---- Settings ----")
print(settings.get_dict())
continue

ms2deepsore_model_file_name = os.path.join(stored_training_data.trained_models_folder,
model_directory_name,
@@ -309,4 +348,4 @@ def load_training_data(self,
return self.load_negative_train_split(data_split_type)
if ionisation_mode == "both":
return self.load_positive_train_split(data_split_type) + self.load_negative_train_split(data_split_type)
raise ValueError("expected ionisation mode to be 'positive', 'negative' or 'both'")
raise ValueError("expected ionisation mode to be 'positive', 'negative' or 'both'")
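For reference, the combination loop in `parameter_search` (keys, values = zip(*setting_variations.items()) followed by one "Testing combination" run per parameter set) is most naturally written with `itertools.product`; a sketch with purely illustrative keys and values:

import itertools

# Hypothetical search space; real keys must be valid SettingsMS2Deepscore fields.
setting_variations = {
    "learning_rate": [1e-4, 1e-3],
    "embedding_dim": [200, 400],
}

keys, values = zip(*setting_variations.items())
for combination in itertools.product(*values):
    params = dict(zip(keys, combination))
    print(f"Testing combination: {params}")
    # parameter_search builds a settings object from these params, reuses the
    # training generator where possible, and trains one model per combination.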
