@@ -156,16 +156,25 @@ def training_step(self, batch: dict, batch_idx: int) -> torch.Tensor:
         :raises RuntimeError: If NaNs are found in the training loss.
         """
         # calls DeepSet.forward()
-        y_pred = self(batch)  # n_samples x n_phenos
-        results = dict()
+        y_pred = self(batch).flatten()  # shape (n_samples * n_phenos,)
+        y_true = batch["phenotypes"].flatten()  # shape (n_samples * n_phenos,)
+        non_nan_mask = ~y_true.isnan()
+        n_non_nan = non_nan_mask.sum()
+
         # for all metrics we want to evaluate (specified in config)
-        y_true = batch["phenotypes"]
+        results = dict()
         for name, fn in self.metric_fns.items():
             # compute mean loss across samples and phenotypes
             # ignore loss where true phenotype is NaN
-            unreduced_loss = torch.where(y_true.isnan(), 0, fn(y_pred, y_true))
-            results[name] = unreduced_loss.sum() / (~y_true.isnan()).sum()
+            if n_non_nan == 0:
+                logger.warning("All target values are NaN in this step")
+                results[name] = torch.tensor(0.0, device=y_pred.device)
+            else:
+                unreduced_loss = fn(y_pred[non_nan_mask], y_true[non_nan_mask])
+                results[name] = unreduced_loss.sum() / n_non_nan
+
             self.log(f"train_{name}", results[name])
+
         # set loss from which we compute backward passes
         loss = results[self.hparams.metrics["loss"]]
         if torch.isnan(loss):
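
Note on the change above: replacing the `torch.where` masking with boolean-mask indexing is not just cosmetic. `torch.where` hides NaNs in the forward value only; in the backward pass the gradient of the discarded branch is multiplied by zero, and `0 * nan` is still `nan`, so a single NaN target poisons the gradients for the whole batch. The sketch below reproduces both variants on a toy tensor, using element-wise squared error as a stand-in for the metric functions configured in `self.metric_fns` (an assumption; the real functions come from the config):

```python
import torch

# Toy targets: the second entry is missing (NaN), as in sparse phenotype data.
y_true = torch.tensor([1.0, float("nan")])

# --- Old approach: zero out NaN positions of the loss with torch.where ----
y_pred = torch.tensor([0.5, 0.5], requires_grad=True)
per_elem = (y_pred - y_true) ** 2  # [0.25, nan]
masked = torch.where(y_true.isnan(), torch.zeros_like(per_elem), per_elem)
loss = masked.sum() / (~y_true.isnan()).sum()
loss.backward()
print(loss.item())  # 0.25 -- the forward pass looks fine
print(y_pred.grad)  # tensor([-1., nan]) -- but the backward pass computes
                    # 0 * nan = nan, poisoning the gradient

# --- New approach: drop NaN positions before computing the loss -----------
y_pred = torch.tensor([0.5, 0.5], requires_grad=True)
mask = ~y_true.isnan()
loss = ((y_pred[mask] - y_true[mask]) ** 2).sum() / mask.sum()
loss.backward()
print(loss.item())  # 0.25 -- same forward value
print(y_pred.grad)  # tensor([-1., 0.]) -- NaN targets contribute nothing
```

Indexing with `non_nan_mask` removes the NaN positions from the autograd graph before the loss is ever computed, so nothing NaN-valued can reach the weight gradients; the `n_non_nan == 0` guard additionally avoids a `0 / 0` division when a batch contains no valid targets at all.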