From 9a5e3a871e36b0fcb52e3e62e3ae89010c3e5f33 Mon Sep 17 00:00:00 2001 From: David Nicholson Date: Fri, 7 Jul 2023 12:45:13 -0400 Subject: [PATCH] BUG/CLN: Add/fix log statements in FrameClassificationModel (#681) --- src/vak/models/frame_classification_model.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/src/vak/models/frame_classification_model.py b/src/vak/models/frame_classification_model.py index 0aa820ced..daaa7d7bb 100644 --- a/src/vak/models/frame_classification_model.py +++ b/src/vak/models/frame_classification_model.py @@ -192,6 +192,7 @@ def training_step(self, batch: tuple, batch_idx: int): x, y = batch[0], batch[1] out = self.network(x) loss = self.loss(out, y) + self.log(f'train_loss', loss) return loss def validation_step(self, batch: tuple, batch_idx: int): @@ -264,19 +265,19 @@ def validation_step(self, batch: tuple, batch_idx: int): # TODO: figure out smarter way to do this for metric_name, metric_callable in self.metrics.items(): if metric_name == "loss": - self.log(f'val_{metric_name}', metric_callable(out, y), batch_size=1) + self.log(f'val_{metric_name}', metric_callable(out, y), batch_size=1, on_step=True) elif metric_name == "acc": self.log(f'val_{metric_name}', metric_callable(y_pred, y), batch_size=1) if self.post_tfm: self.log(f'val_{metric_name}_tfm', metric_callable(y_pred_tfm, y), - batch_size=1) + batch_size=1, on_step=True) elif metric_name == "levenshtein" or metric_name == "segment_error_rate": self.log(f'val_{metric_name}', metric_callable(y_pred_labels, y_labels), batch_size=1) if self.post_tfm: self.log(f'val_{metric_name}_tfm', metric_callable(y_pred_tfm_labels, y_labels), - batch_size=1) + batch_size=1, on_step=True) def predict_step(self, batch: tuple, batch_idx: int): """Perform one prediction step.