diff --git a/keytotext/trainer.py b/keytotext/trainer.py index 55950b0..814e7c9 100644 --- a/keytotext/trainer.py +++ b/keytotext/trainer.py @@ -214,6 +214,7 @@ def validation_step(self, batch, batch_size): ) val_acc = self.val_acc(outputs.logits.argmax(1), labels) self.log("val_loss", loss, prog_bar=True, logger=True) + self.log(f"val_acc", val_acc, prog_bar=True,logger=True) return loss def test_step(self, batch, batch_size):