Skip to content

Commit

Permalink
added f1, precision and recall to logging
Browse files Browse the repository at this point in the history
  • Loading branch information
LostOxygen committed Jul 25, 2023
1 parent 54104d6 commit 0151802
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 9 deletions.
6 changes: 5 additions & 1 deletion kernel_eval/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,9 @@ def train_model(model: nn.Module, train_dataloader: IterableDataset,
train_accuracy: float - the accuracy at the end of the training
train_losses: List[float] - the losses at the end of each epoch
train_accs: List[float] - the accuracies at the end of each epoch
f1_score: float - the f1 score at the end of the training
precision: float - the precision at the end of the training
recall: float - the recall at the end of the training
"""

wandb.init(
Expand Down Expand Up @@ -166,7 +169,8 @@ def train_model(model: nn.Module, train_dataloader: IterableDataset,
best_epoch = epoch
save_model(mpath_out, model_type, depthwise, batch_size, learning_rate, epochs, model)

return best_model, train_accs[best_epoch], train_accs, train_losses
return (best_model, train_accs[best_epoch], train_accs, train_losses,
f1_score, precision, recall)


def test_model(model: nn.Module, dataloader: IterableDataset, device: str="cpu") -> float:
Expand Down
10 changes: 8 additions & 2 deletions kernel_eval/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,14 +91,18 @@ def create_1gb_random_array() -> None:
np.save("1gb_array.npy", arr)


def log_metrics(train_acc: float, test_acc: torch.Tensor, model_name: str) -> None:
def log_metrics(train_acc: float, test_acc: torch.Tensor, model_name: str,
f1_score: float, precision: float, recall: float) -> None:
"""
Logs score and loss for a model over epochs and saves the log under ./logs/model_name.log
Parameters:
scores: torch.Tensor with the current scoreof a given model
loss: torch.Tensor with the current loss of a given model
epoch: current epoch
model_name: the name of the model
f1_score: the f1 score of the model
precision: the precision of the model
recall: the recall of the model
Returns:
None
"""
Expand All @@ -108,7 +112,9 @@ def log_metrics(train_acc: float, test_acc: torch.Tensor, model_name: str) -> No
try:
with open(f"./logs/{model_name}.log", encoding="utf-8", mode="a") as log_file:
log_file.write(f"{datetime.now().strftime('%A, %d. %B %Y %I:%M%p')}" \
f" - train_acc: {train_acc} - test_acc: {test_acc}\n")
f" - train_acc: {train_acc} - test_acc: {test_acc}" \
f" - f1_score: {f1_score} - precision: {precision}" \
f" - recall: {recall}\n")
except OSError as error:
print(f"Could not write logs into /logs/{model_name}.log - error: {error}")

Expand Down
22 changes: 16 additions & 6 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,11 +115,17 @@ def main(gpu: int, batch_size: int, epochs: int, model_type: str,

if not eval_only:
print("[ train model ]")
model, best_acc, train_accs, train_losses = train_model(model, train_loader,
validation_loader, learning_rate,
epochs, batch_size, device,
model_type, depthwise,
MODEL_OUTPUT_PATH)
model_w_data = train_model(model, train_loader, validation_loader,
learning_rate, epochs, batch_size, device,
model_type, depthwise, MODEL_OUTPUT_PATH)
model = model_w_data[0] # the trained model itself
best_acc = model_w_data[1] # best accuracy
train_accs = model_w_data[2] # list of all accuracies
train_losses = model_w_data[3] # list of all train losses
f1_score = model_w_data[4]
precision = model_w_data[5]
recall = model_w_data[6]

del train_loader

# -------- Test Models and Evaluate Kernels ------------
Expand All @@ -140,7 +146,11 @@ def main(gpu: int, batch_size: int, epochs: int, model_type: str,

if not eval_only:
log_metrics(train_acc=best_acc,
test_acc=test_accuracy, model_name=model_name)
test_acc=test_accuracy,
model_name=model_name,
f1_score=f1_score,
precision=precision,
recall=recall)
plot_metrics(train_acc=train_accs,
train_loss=train_losses, model_name=model_name)

Expand Down

0 comments on commit 0151802

Please sign in to comment.