From 160c2f69ee7d838aeb2f887778b75bb1eb0283ed Mon Sep 17 00:00:00 2001 From: dimakarp1996 Date: Tue, 31 Jan 2023 22:44:16 +0300 Subject: [PATCH] Added option "always save model" --- deeppavlov/core/trainers/nn_trainer.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/deeppavlov/core/trainers/nn_trainer.py b/deeppavlov/core/trainers/nn_trainer.py index 790b188937..d8bf34a677 100644 --- a/deeppavlov/core/trainers/nn_trainer.py +++ b/deeppavlov/core/trainers/nn_trainer.py @@ -76,6 +76,8 @@ class NNTrainer(FitTrainer): log_on_k_batches: count of random train batches to calculate metrics in log (default is ``1``) max_test_batches: maximum batches count for pipeline testing and evaluation, overrides ``log_on_k_batches``, ignored if negative (default is ``-1``) + always_save_model: if True, we always save the obtained weights of our model, regardless of the metric. + (default is ``False``) **kwargs: additional parameters whose names will be logged but otherwise ignored @@ -107,6 +109,7 @@ def __init__(self, chainer_config: dict, *, validate_first: bool = True, validation_patience: int = 5, val_every_n_epochs: int = -1, val_every_n_batches: int = -1, log_every_n_batches: int = -1, log_every_n_epochs: int = -1, log_on_k_batches: int = 1, + always_save_model: bool = False, **kwargs) -> None: super().__init__(chainer_config, batch_size=batch_size, metrics=metrics, evaluation_targets=evaluation_targets, show_examples=show_examples, max_test_batches=max_test_batches, **kwargs) @@ -141,6 +144,7 @@ def _improved(op): self.max_epochs = epochs self.epoch = start_epoch_num self.max_batches = max_batches + self.always_save_model = always_save_model self.train_batches_seen = 0 self.examples = 0 @@ -207,6 +211,11 @@ def _validate(self, iterator: DataLearningIterator, self.score_best = score log.info('Saving model') self.save() + elif self.always_save_model: + log.info(f'Changed {m_name} from {self.score_best} to {score}') + self.score_best = score + log.info('But due to always_save_model, saving the model') + self.save() else: log.info('Did not improve on the {} of {}'.format(m_name, self.score_best))