diff --git a/ms2query/create_new_library/train_ms2deepscore.py b/ms2query/create_new_library/train_ms2deepscore.py index e279e760..82626488 100644 --- a/ms2query/create_new_library/train_ms2deepscore.py +++ b/ms2query/create_new_library/train_ms2deepscore.py @@ -43,7 +43,7 @@ def train_ms2ds_model(training_spectra, reference_scores_df=tanimoto_df, dim=len(spectrum_binner.known_bins), # The number of bins created same_prob_bins=same_prob_bins, - num_turns=2, + num_turns=1, augment_noise_max=10, augment_noise_intensity=0.01) @@ -58,9 +58,9 @@ def train_ms2ds_model(training_spectra, augment_removal_max=0, augment_removal_intensity=0, augment_intensity=0, augment_noise_max=0, use_fixed_set=True ) - model = SiameseModel(spectrum_binner, base_dims=(500, 500), embedding_dim=200, dropout_rate=0.2) + model = SiameseModel(spectrum_binner, base_dims=(1000, 1000, 1000), embedding_dim=500, dropout_rate=0.2) - model.compile(loss='mse', optimizer=Adam(lr=0.001), metrics=["mae", tf.keras.metrics.RootMeanSquaredError()]) + model.compile(loss='mse', optimizer=Adam(lr=0.0005), metrics=["mae", tf.keras.metrics.RootMeanSquaredError()]) # Save best model and include early stopping checkpointer = ModelCheckpoint(filepath=output_model_file_name, monitor='val_loss', mode="min", verbose=1, save_best_only=True)