Skip to content

Commit

Permalink
added validation set metrics logging
Browse files Browse the repository at this point in the history
  • Loading branch information
LostOxygen committed Jul 30, 2023
1 parent 6818229 commit f7d6925
Showing 1 changed file with 3 additions and 2 deletions.
5 changes: 3 additions & 2 deletions kernel_eval/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,7 @@ def train_model(model: nn.Module, train_dataloader: IterableDataset,

# for every epoch use the validation set to check if this is the best model yet
validation_acc = test_model(model, validation_dataloader, device)
wandb.log({"validation_acc": validation_acc})

if validation_acc > best_validation_acc:
best_validation_acc = validation_acc
Expand All @@ -149,7 +150,7 @@ def test_model(model: nn.Module, dataloader: IterableDataset, device: str="cpu")
device: str - the device to test on (cpu or cuda)
Returns:
test_accuracy: float - the test accuracy
test_accuracy: float - the test accuracy in percent
"""
# test the model without gradient calculation and in evaluation mode
tp_val: float = 0
Expand Down Expand Up @@ -187,4 +188,4 @@ def test_model(model: nn.Module, dataloader: IterableDataset, device: str="cpu")
"train_f1": f1_score, "test_acc": 100 * correct / total})
print(f"Test Accuracy: {accuracy}%")

return accuracy, precision, recall, f1_score
return (accuracy, precision, recall, f1_score)

0 comments on commit f7d6925

Please sign in to comment.