Commit

code style fixes
LostOxygen committed Jun 16, 2023
1 parent 40c0fd3 commit 34b51cb
Showing 1 changed file with 17 additions and 17 deletions.
34 changes: 17 additions & 17 deletions kernel_eval/train.py
@@ -4,12 +4,11 @@
from torch import optim
from torch import nn
from torch.utils.data import IterableDataset
+from torch.optim.lr_scheduler import ReduceLROnPlateau
import pkbar
from tqdm import tqdm
from kernel_eval.utils import save_model
import wandb
import numpy as np
-from torch.optim.lr_scheduler import ReduceLROnPlateau


def adjust_learning_rate(optimizer, epoch: int, epochs: int, learning_rate: float) -> None:
@@ -35,7 +34,8 @@ def adjust_learning_rate(optimizer, epoch: int, epochs: int, learning_rate: float)
param_group["lr"] = new_lr


-def train_model(model: nn.Module, train_dataloader: IterableDataset, validation_dataloader: IterableDataset,
+def train_model(model: nn.Module, train_dataloader: IterableDataset,
+                validation_dataloader: IterableDataset,
learning_rate: float, epochs: int, batch_size: int,
device: str = "cpu", model_type: str = "no_model", depthwise: bool = True,
mpath_out: str = "./models/") -> Tuple[nn.Module, float, List[float], List[float]]:
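
For context, here is a minimal sketch of how the reflowed signature might be called. The model, data, and argument values are illustrative stand-ins rather than code from this repository, and the four-way unpacking assumes the return annotation describes a tuple of (model, accuracy, losses, accuracies):

    import torch
    from torch import nn
    from torch.utils.data import DataLoader, TensorDataset

    from kernel_eval.train import train_model

    # Illustrative stand-ins: a tiny binary classifier on random data.
    net = nn.Sequential(nn.Flatten(), nn.Linear(16, 2))
    data = TensorDataset(torch.randn(64, 1, 4, 4), torch.randint(0, 2, (64,)))
    train_loader = DataLoader(data, batch_size=8)
    val_loader = DataLoader(data, batch_size=8)

    net, accuracy, losses, accs = train_model(
        net, train_loader, val_loader,
        learning_rate=1e-3, epochs=5, batch_size=8,
        device="cuda" if torch.cuda.is_available() else "cpu")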
@@ -93,18 +93,18 @@ def train_model(model: nn.Module, train_dataloader: IterableDataset, validation_
# also, depending on the epoch the learning rate gets adjusted before
# the network is set into training mode

-TP = 0
-FP = 0
-FN = 0
+tp_val: float = 0
+fp_val: float = 0
+fn_val: float = 0
model.train()
kbar = pkbar.Kbar(target=len(train_dataloader) - 1, epoch=epoch, num_epochs=epochs,
width=20, always_stateful=True)

-correct = 0
-total = 0
-running_loss = 0.0
-epoch_loss = []
-epoch_acc = []
+correct: int = 0
+total: int = 0
+running_loss: float = 0.0
+epoch_loss: List[float] = []
+epoch_acc: List[float] = []
# adjust the learning rate
# adjust_learning_rate(optimizer, epoch, epochs, learning_rate)
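
The commented-out call above and the ReduceLROnPlateau import point at two alternative learning-rate schedules. A minimal sketch of the on-plateau variant, with a placeholder model and validation loss standing in for the real ones (nothing below is taken from train.py itself):

    import torch
    from torch import nn, optim
    from torch.optim.lr_scheduler import ReduceLROnPlateau

    net = nn.Linear(16, 2)  # placeholder model
    optimizer = optim.SGD(net.parameters(), lr=0.1)
    # Halve the learning rate once the monitored metric has stopped
    # improving for three consecutive epochs.
    scheduler = ReduceLROnPlateau(optimizer, mode="min", factor=0.5, patience=3)

    for epoch in range(20):
        val_loss = 1.0  # placeholder: a stagnating validation loss
        scheduler.step(val_loss)  # this scheduler steps on the metric itself
        print(epoch, optimizer.param_groups[0]["lr"])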

@@ -131,10 +131,10 @@ def train_model(model: nn.Module, train_dataloader: IterableDataset, validation_
epoch_acc.append(100. * correct / total)

# Calculate metrics
-# Update TP, FP, and FN counters
-TP += ((predicted == 1) & (label == 1)).sum().item()
-FP += ((predicted == 1) & (label == 0)).sum().item()
-FN += ((predicted == 0) & (label == 1)).sum().item()
+# Update tp_val, fp_val, and fn_val counters
+tp_val += ((predicted == 1) & (label == 1)).sum().item()
+fp_val += ((predicted == 1) & (label == 0)).sum().item()
+fn_val += ((predicted == 0) & (label == 1)).sum().item()

kbar.update(batch_idx, values=[("loss", running_loss/(batch_idx+1)),
("acc", 100. * correct / total)])
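
The masked sums above count true positives, false positives, and false negatives one batch at a time. The same idiom on hand-picked toy tensors, so the counts are easy to verify by eye (not repository code):

    import torch

    predicted = torch.tensor([1, 1, 0, 0, 1])
    label = torch.tensor([1, 0, 0, 1, 1])

    # Element-wise boolean masks, summed into scalar counts.
    tp = ((predicted == 1) & (label == 1)).sum().item()  # 2 (positions 0 and 4)
    fp = ((predicted == 1) & (label == 0)).sum().item()  # 1 (position 1)
    fn = ((predicted == 0) & (label == 1)).sum().item()  # 1 (position 3)
    print(tp, fp, fn)  # 2 1 1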
@@ -150,8 +150,8 @@ def train_model(model: nn.Module, train_dataloader: IterableDataset, validation_

# Calculate average metrics for the epoch
# Calculate precision, recall, and F1 score
-precision = TP / (TP + FP) if TP + FP > 0 else 0
-recall = TP / (TP + FN) if TP + FN > 0 else 0
+precision = tp_val / (tp_val + fp_val) if tp_val + fp_val > 0 else 0
+recall = tp_val / (tp_val + fn_val) if tp_val + fn_val > 0 else 0
f1_score = 2 * (precision * recall) / (precision + recall) if precision + recall > 0 else 0

wandb.log({"train_recall": recall, "train_precision": precision, "train_f1": f1_score})
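
Feeding the toy counts from the previous sketch through these formulas gives a quick sanity check of the guarded divisions (plain arithmetic, no repository assumptions):

    tp_val, fp_val, fn_val = 2, 1, 1

    precision = tp_val / (tp_val + fp_val) if tp_val + fp_val > 0 else 0  # 2/3
    recall = tp_val / (tp_val + fn_val) if tp_val + fn_val > 0 else 0  # 2/3
    f1_score = 2 * (precision * recall) / (precision + recall) if precision + recall > 0 else 0  # also 2/3
    print(round(precision, 3), round(recall, 3), round(f1_score, 3))  # 0.667 0.667 0.667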
