Commit

fixed metrics calculations

LostOxygen committed Jul 30, 2023
1 parent 0151802 commit 6818229

Showing 2 changed files with 43 additions and 43 deletions.
kernel_eval/train.py: 21 additions & 39 deletions
@@ -62,21 +62,6 @@ def train_model(model: nn.Module, train_dataloader: IterableDataset,
         precision: float - the precision at the end of the training
         recall: float - the recall at the end of the training
     """
-
-    wandb.init(
-        # set the wandb project where this run will be logged
-        project="kernel_optimization",
-
-        # track hyperparameters and run metadata
-        config={
-            "learning_rate": str(learning_rate),
-            "architecture": model_type,
-            "dataset": "bioimages",
-            "epochs": str(epochs),
-            "depthwise": depthwise
-        }
-    )
-
     # initialize model, loss function, optimizer and so on
     train_accs: List[float] = []
     train_losses: List[float] = []

@@ -96,9 +81,6 @@ def train_model(model: nn.Module, train_dataloader: IterableDataset,
        # also, depending on the epoch the learning rate gets adjusted before
        # the network is set into training mode

-        tp_val: float = 0
-        fp_val: float = 0
-        fn_val: float = 0
        model.train()
        kbar = pkbar.Kbar(target=len(train_dataloader) - 1, epoch=epoch, num_epochs=epochs,
                          width=20, always_stateful=True)

@@ -133,12 +115,6 @@ def train_model(model: nn.Module, train_dataloader: IterableDataset,
            correct += predicted.eq(label).sum().item()
            epoch_acc.append(100. * correct / total)

-            # Calculate metrics
-            # Update tp_val, fp_val, and fn_val counters
-            tp_val += ((predicted == 1) & (label == 1)).sum().item()
-            fp_val += ((predicted == 1) & (label == 0)).sum().item()
-            fn_val += ((predicted == 0) & (label == 1)).sum().item()
-
            kbar.update(batch_idx, values=[("loss", running_loss/(batch_idx+1)),
                                           ("acc", 100. * correct / total)])
            wandb.log({"train_acc": 100 * correct/total, "train_loss": running_loss/(batch_idx+1)})

@@ -151,15 +127,6 @@ def train_model(model: nn.Module, train_dataloader: IterableDataset,
        train_accs.append(sum(epoch_acc) / len(epoch_acc))
        train_losses.append(sum(epoch_loss) / len(epoch_loss))

-        # Calculate average metrics for the epoch
-        # Calculate precision, recall, and F1 score
-        precision = tp_val / (tp_val + fp_val) if tp_val + fp_val > 0 else 0
-        recall = tp_val / (tp_val + fn_val) if tp_val + fn_val > 0 else 0
-        f1_score = 2 * (precision * recall) / (precision + recall) if precision + recall > 0 else 0
-
-        wandb.log({"train_recall": recall, "train_precision": precision, "train_f1": f1_score})
-
-
        # for every epoch use the validation set to check if this is the best model yet
        validation_acc = test_model(model, validation_dataloader, device)

@@ -169,8 +136,7 @@ def train_model(model: nn.Module, train_dataloader: IterableDataset,
            best_epoch = epoch
            save_model(mpath_out, model_type, depthwise, batch_size, learning_rate, epochs, model)

-    return (best_model, train_accs[best_epoch], train_accs, train_losses,
-            f1_score, precision, recall)
+    return (best_model, train_accs[best_epoch], train_accs, train_losses)

@@ -186,7 +152,9 @@ def test_model(model: nn.Module, dataloader: IterableDataset, device: str="cpu") -> float:
        test_accuracy: float - the test accuracy
    """
    # test the model without gradient calculation and in evaluation mode
-
+    tp_val: float = 0
+    fp_val: float = 0
+    fn_val: float = 0

    with torch.no_grad():
        model = model.to(device)

@@ -202,7 +170,21 @@ def test_model(model: nn.Module, dataloader: IterableDataset, device: str="cpu")
            predicted = output.round()
            total += label.size(0)
            correct += predicted.eq(label).sum().item()
-            # wandb.log({"test_acc": 100 * correct / total})
-    print(f"Test Accuracy: {100. * correct / total}%")
+            # Calculate metrics
+            # Update tp_val, fp_val, and fn_val counters
+            tp_val += ((predicted == 1) & (label == 1)).sum().item()
+            fp_val += ((predicted == 1) & (label == 0)).sum().item()
+            fn_val += ((predicted == 0) & (label == 1)).sum().item()
+
+    # Calculate average metrics for the epoch
+    # Calculate precision, recall, and F1 score
+    accuracy = 100. * correct / total
+    precision = tp_val / (tp_val + fp_val) if tp_val + fp_val > 0 else 0
+    recall = tp_val / (tp_val + fn_val) if tp_val + fn_val > 0 else 0
+    f1_score = 2 * (precision * recall) / (precision + recall) if precision + recall > 0 else 0
+
+    wandb.log({"train_recall": recall, "train_precision": precision,
+               "train_f1": f1_score, "test_acc": 100 * correct / total})
+    print(f"Test Accuracy: {accuracy}%")

-    return 100. * correct / total
+    return accuracy, precision, recall, f1_score
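
For reference, the metric arithmetic that test_model now performs can be checked in isolation. The following is a minimal sketch with made-up tensors; the variable names mirror the diff, but nothing here is repository code:

import torch

# toy batch: ground-truth labels and rounded model outputs
label = torch.tensor([1., 0., 1., 0.])
predicted = torch.tensor([1., 1., 0., 0.])

# confusion-matrix counters, accumulated exactly as in the new test_model
tp_val = ((predicted == 1) & (label == 1)).sum().item()  # 1
fp_val = ((predicted == 1) & (label == 0)).sum().item()  # 1
fn_val = ((predicted == 0) & (label == 1)).sum().item()  # 1

precision = tp_val / (tp_val + fp_val) if tp_val + fp_val > 0 else 0  # 0.5
recall = tp_val / (tp_val + fn_val) if tp_val + fn_val > 0 else 0     # 0.5
f1_score = 2 * (precision * recall) / (precision + recall) if precision + recall > 0 else 0  # 0.5

print(precision, recall, f1_score)  # 0.5 0.5 0.5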
main.py: 22 additions & 4 deletions
@@ -9,6 +9,7 @@
 from typing import Final, List
 import torch
 import torchsummary
+import wandb

 from torch.utils.data import DataLoader

@@ -113,6 +114,22 @@ def main(gpu: int, batch_size: int, epochs: int, model_type: str,
    model_name = model_type + f"_{batch_size}bs_{learning_rate}lr_{epochs}ep"
    model_name += f"{'_depthwise' if depthwise else ''}"

+    # initialize WandB logging
+    wandb.init(
+        # set the wandb project where this run will be logged
+        project="kernel_optimization",
+        name=model_name,
+
+        # track hyperparameters and run metadata
+        config={
+            "learning_rate": str(learning_rate),
+            "architecture": model_type,
+            "dataset": "bioimages",
+            "epochs": str(epochs),
+            "depthwise": depthwise
+        }
+    )
+
    if not eval_only:
        print("[ train model ]")
        model_w_data = train_model(model, train_loader, validation_loader,

@@ -122,9 +139,6 @@ def main(gpu: int, batch_size: int, epochs: int, model_type: str,
        best_acc = model_w_data[1]      # best accuracy
        train_accs = model_w_data[2]    # list of all accuracies
        train_losses = model_w_data[3]  # list of all train losses
-        f1_score = model_w_data[4]
-        precision = model_w_data[5]
-        recall = model_w_data[6]

        del train_loader

@@ -142,7 +156,11 @@ def main(gpu: int, batch_size: int, epochs: int, model_type: str,
                   batch_size, learning_rate, epochs, model)

    print("[ evaluate model ]")
-    test_accuracy = test_model(model, test_loader, device)
+    test_metrics = test_model(model, test_loader, device)
+    test_accuracy = test_metrics[0]
+    precision = test_metrics[1]
+    recall = test_metrics[2]
+    f1_score = test_metrics[3]

    if not eval_only:
        log_metrics(train_acc=best_acc,
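
Since test_model now returns a 4-tuple instead of a bare float, the four indexed assignments above could equally be written as a single unpacking assignment. An equivalent one-line sketch:

test_accuracy, precision, recall, f1_score = test_model(model, test_loader, device)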

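
Moving wandb.init out of train_model and into main means a single WandB run now receives both the per-batch training metrics and the final test metrics (note that test_model still logs its precision, recall, and F1 under train_-prefixed keys, as the diff shows). A minimal standalone sketch of that flow; the run name and metric values are invented, and only the project, config keys, and metric keys come from the diff:

import wandb

# one run for the whole pipeline, initialized once in main
run = wandb.init(
    project="kernel_optimization",
    name="example_run",  # the real code passes model_name here
    config={"dataset": "bioimages", "epochs": "10"},  # illustrative values
)

wandb.log({"train_acc": 95.2, "train_loss": 0.13})  # logged per batch in train_model
wandb.log({"train_f1": 0.88, "test_acc": 91.7})     # logged once in test_model
wandb.finish()  # end the run cleanly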