Skip to content

Commit 8ea09e5

Browse files
committed
Fix `None` gradients caused by all-NaN phenotype targets in a training step
1 parent bcf8543 commit 8ea09e5

File tree

1 file changed

+14
-5
lines changed

1 file changed

+14
-5
lines changed

deeprvat/deeprvat/models_anngeno.py

+14-5
Original file line numberDiff line numberDiff line change
@@ -156,16 +156,25 @@ def training_step(self, batch: dict, batch_idx: int) -> torch.Tensor:
156156
:raises RuntimeError: If NaNs are found in the training loss.
157157
"""
158158
# calls DeepSet.forward()
159-
y_pred = self(batch) # n_samples x n_phenos
160-
results = dict()
159+
y_pred = self(batch).flatten() # shape (n_samples * n_phenos, )
160+
y_true = batch["phenotypes"].flatten() # shape (n_samples x n_phenos, )
161+
non_nan_mask = ~y_true.isnan()
162+
n_non_nan = non_nan_mask.sum()
163+
161164
# for all metrics we want to evaluate (specified in config)
162-
y_true = batch["phenotypes"]
165+
results = dict()
163166
for name, fn in self.metric_fns.items():
164167
# compute mean loss across samples and phenotypes
165168
# ignore loss where true phenotype is NaN
166-
unreduced_loss = torch.where(y_true.isnan(), 0, fn(y_pred, y_true))
167-
results[name] = unreduced_loss.sum() / (~y_true.isnan()).sum()
169+
if n_non_nan == 0:
170+
logger.warning("All target values are NaN in this step")
171+
results[name] = 0
172+
else:
173+
unreduced_loss = fn(y_pred[non_nan_mask], y_true[non_nan_mask])
174+
results[name] = unreduced_loss.sum() / n_non_nan if n_non_nan > 0 else 0
175+
168176
self.log(f"train_{name}", results[name])
177+
169178
# set loss from which we compute backward passes
170179
loss = results[self.hparams.metrics["loss"]]
171180
if torch.isnan(loss):

0 commit comments

Comments
 (0)