Skip to content

Commit 230394b

Browse files
committed
Relocate the sigmoid/softmax application to be conditional on the number of channels.
Signed-off-by: ytl0623 <[email protected]>
1 parent 912235e commit 230394b

File tree

1 file changed

+12
-1
lines changed

1 file changed

+12
-1
lines changed

monai/losses/unified_focal_loss.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -253,9 +253,20 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
253253
raise ValueError(
254254
f"Single-channel input only supported for binary (num_classes=2), got {self.num_classes}"
255255
)
256-
y_pred = torch.cat([torch.zeros_like(y_pred), y_pred], dim=1)
256+
257+
if self.use_softmax:
258+
raise ValueError("use_softmax=True is not compatible with single-channel input")
259+
260+
y_pred_sigmoid = torch.sigmoid(y_pred.float())
261+
y_pred = torch.cat([1 - y_pred_sigmoid, y_pred_sigmoid], dim=1)
262+
257263
if y_true.shape[1] == 1:
258264
y_true = one_hot(y_true, num_classes=self.num_classes)
265+
else:
266+
if self.use_softmax:
267+
y_pred = torch.softmax(y_pred.float(), dim=1)
268+
else:
269+
y_pred = torch.sigmoid(y_pred.float())
259270

260271
if y_true.shape[1] != self.num_classes or torch.max(y_true) > self.num_classes - 1:
261272
raise ValueError(

0 commit comments

Comments
 (0)