Skip to content

Commit

Permalink
fix wandb solutions plotting
Browse files Browse the repository at this point in the history
  • Loading branch information
SkafteNicki committed Jan 14, 2025
1 parent e1da4cf commit 868badc
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -68,8 +68,9 @@ def train(lr: float = 0.001, batch_size: int = 32, epochs: int = 5) -> None:
plot_chance_level=(class_id == 2),
)

# alternatively use wandb.log({"roc": wandb.Image(plt)}
wandb.plot({"roc": plt})
# alternatively the wandb.plot.roc_curve function can be used
plt.close() # close the plot to avoid memory leaks and overlapping figures


if __name__ == "__main__":
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,8 +68,9 @@ def train(lr: float = 0.001, batch_size: int = 32, epochs: int = 5) -> None:
plot_chance_level=(class_id == 2),
)

# alternatively use wandb.log({"roc": wandb.Image(plt)}
wandb.plot({"roc": plt})
# alternatively the wandb.plot.roc_curve function can be used
plt.close() # close the plot to avoid memory leaks and overlapping figures

final_accuracy = accuracy_score(targets, preds.argmax(dim=1))
final_precision = precision_score(targets, preds.argmax(dim=1), average="weighted")
Expand Down

0 comments on commit 868badc

Please sign in to comment.