Skip to content

Commit b7a5013

Browse files
committed
Validation logic may reject valid inputs.
Signed-off-by: ytl0623 <[email protected]>
1 parent c0e9d78 commit b7a5013

File tree

1 file changed

+10
-3
lines changed

1 file changed

+10
-3
lines changed

monai/losses/unified_focal_loss.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -248,10 +248,17 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
248248
if len(y_pred.shape) != 4 and len(y_pred.shape) != 5:
249249
raise ValueError(f"input shape must be 4 or 5, but got {y_pred.shape}")
250250

251-
if y_true.shape[1] != self.num_classes or torch.max(y_true) > self.num_classes - 1:
251+
if y_true.shape[1] == self.num_classes:
252+
if not torch.all((y_true == 0) | (y_true == 1)):
253+
raise ValueError(f"y_true appears to be one-hot but contains values other than 0 and 1")
254+
elif y_true.shape[1] == 1:
255+
if torch.max(y_true) >= self.num_classes:
256+
raise ValueError(
257+
f"y_true labels must be in [0, {self.num_classes - 1}], but got max {torch.max(y_true)}"
258+
)
259+
else:
252260
raise ValueError(
253-
f"y_true must have {self.num_classes} channels (one-hot) or label values in [0, {self.num_classes - 1}], "
254-
f"but got shape {y_true.shape} with max value {torch.max(y_true)}"
261+
f"y_true must have {self.num_classes} channels (one-hot) or 1 channel (labels), got {y_true.shape[1]}"
255262
)
256263

257264
n_pred_ch = y_pred.shape[1]

0 commit comments

Comments
 (0)