Skip to content

Commit

Permalink
BUG/CLN: Add/fix log statements in FrameClassificationModel (#681)
Browse files Browse the repository at this point in the history
  • Loading branch information
NickleDave authored Jul 7, 2023
1 parent a2ae654 commit 9a5e3a8
Showing 1 changed file with 4 additions and 3 deletions.
7 changes: 4 additions & 3 deletions src/vak/models/frame_classification_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,7 @@ def training_step(self, batch: tuple, batch_idx: int):
x, y = batch[0], batch[1]
out = self.network(x)
loss = self.loss(out, y)
self.log(f'train_loss', loss)
return loss

def validation_step(self, batch: tuple, batch_idx: int):
Expand Down Expand Up @@ -264,19 +265,19 @@ def validation_step(self, batch: tuple, batch_idx: int):
# TODO: figure out smarter way to do this
for metric_name, metric_callable in self.metrics.items():
if metric_name == "loss":
self.log(f'val_{metric_name}', metric_callable(out, y), batch_size=1)
self.log(f'val_{metric_name}', metric_callable(out, y), batch_size=1, on_step=True)
elif metric_name == "acc":
self.log(f'val_{metric_name}', metric_callable(y_pred, y), batch_size=1)
if self.post_tfm:
self.log(f'val_{metric_name}_tfm',
metric_callable(y_pred_tfm, y),
batch_size=1)
batch_size=1, on_step=True)
elif metric_name == "levenshtein" or metric_name == "segment_error_rate":
self.log(f'val_{metric_name}', metric_callable(y_pred_labels, y_labels), batch_size=1)
if self.post_tfm:
self.log(f'val_{metric_name}_tfm',
metric_callable(y_pred_tfm_labels, y_labels),
batch_size=1)
batch_size=1, on_step=True)

def predict_step(self, batch: tuple, batch_idx: int):
"""Perform one prediction step.
Expand Down

0 comments on commit 9a5e3a8

Please sign in to comment.