From e8292d9dbd72158c90bbbae96e3a3df8b941279c Mon Sep 17 00:00:00 2001 From: Alexandra Udaltsova <43303448+AUdaltsova@users.noreply.github.com> Date: Wed, 25 Sep 2024 12:59:44 +0100 Subject: [PATCH 1/3] add flushing of val epoch resluts --- pvnet/models/base_model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pvnet/models/base_model.py b/pvnet/models/base_model.py index 8665bf3c..e7283b92 100644 --- a/pvnet/models/base_model.py +++ b/pvnet/models/base_model.py @@ -1,4 +1,3 @@ -"""Base model for all PVNet submodels""" import json import logging import os @@ -743,6 +742,7 @@ def on_validation_epoch_end(self): print("Failed to log validation results to wandb") print(e) + self.validation_epoch_results = [] horizon_maes_dict = self._horizon_maes.flush() # Create the horizon accuracy curve From efbd95cce67c2ceee5e497e5548c6e53d2021146 Mon Sep 17 00:00:00 2001 From: Alexandra Udaltsova <43303448+AUdaltsova@users.noreply.github.com> Date: Wed, 25 Sep 2024 13:08:00 +0100 Subject: [PATCH 2/3] add docstring back base_model.py --- pvnet/models/base_model.py | 1 + 1 file changed, 1 insertion(+) diff --git a/pvnet/models/base_model.py b/pvnet/models/base_model.py index e7283b92..630c4aae 100644 --- a/pvnet/models/base_model.py +++ b/pvnet/models/base_model.py @@ -1,3 +1,4 @@ +"""Base model for all PVNet submodels""" import json import logging import os From ef0dd3f07021dd3e8dd6efb99fecf89c4117b3af Mon Sep 17 00:00:00 2001 From: Alexandra Udaltsova <43303448+AUdaltsova@users.noreply.github.com> Date: Wed, 25 Sep 2024 17:59:38 +0100 Subject: [PATCH 3/3] log to csv only when end of accumbatch --- pvnet/models/base_model.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pvnet/models/base_model.py b/pvnet/models/base_model.py index 630c4aae..83e67e7b 100644 --- a/pvnet/models/base_model.py +++ b/pvnet/models/base_model.py @@ -666,7 +666,8 @@ def validation_step(self, batch: dict, batch_idx): # Sensor seems to be in batch, station, time order y = batch[self._target_key][:, -self.forecast_len :, 0] - self._log_validation_results(batch, y_hat, accum_batch_num) + if (batch_idx + 1) % self.trainer.accumulate_grad_batches == 0: + self._log_validation_results(batch, y_hat, accum_batch_num) # Expand persistence to be the same shape as y losses = self._calculate_common_losses(y, y_hat)