Skip to content

Commit 52ccd35

Browse files
committed
fix: Binary segmentation foreground loss never evaluated
Signed-off-by: ytl0623 <[email protected]>
1 parent ea8a6ee commit 52ccd35

File tree

1 file changed

+22
-17
lines changed

1 file changed

+22
-17
lines changed

monai/losses/unified_focal_loss.py

Lines changed: 22 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,10 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
6464
else:
6565
y_pred = torch.sigmoid(y_pred)
6666

67+
if y_pred.shape[1] == 1:
68+
y_pred = torch.cat([1 - y_pred, y_pred], dim=1)
69+
y_true = torch.cat([1 - y_true, y_true], dim=1)
70+
6771
n_pred_ch = y_pred.shape[1]
6872

6973
if self.to_onehot_y:
@@ -77,7 +81,6 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
7781

7882
# clip the prediction to avoid NaN
7983
y_pred = torch.clamp(y_pred, self.epsilon, 1.0 - self.epsilon)
80-
8184
axis = list(range(2, len(y_pred.shape)))
8285

8386
# Calculate true positives (tp), false negatives (fn) and false positives (fp)
@@ -86,18 +89,16 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
8689
fp = torch.sum((1 - y_true) * y_pred, dim=axis)
8790
dice_class = (tp + self.epsilon) / (tp + self.delta * fn + (1 - self.delta) * fp + self.epsilon)
8891

89-
# Calculate losses separately for each class, enhancing both classes
92+
# Class 0 is Background
9093
back_dice = 1 - dice_class[:, 0]
9194

92-
if n_pred_ch > 1:
93-
fore_dice = torch.pow(1 - dice_class[:, 1:], 1 - self.gamma)
95+
# Class 1+ is Foreground
96+
fore_dice = torch.pow(1 - dice_class[:, 1:], 1 - self.gamma)
9497

95-
if fore_dice.shape[1] > 1:
96-
fore_dice = torch.mean(fore_dice, dim=1)
97-
else:
98-
fore_dice = fore_dice.squeeze(1)
98+
if fore_dice.shape[1] > 1:
99+
fore_dice = torch.mean(fore_dice, dim=1)
99100
else:
100-
fore_dice = torch.zeros_like(back_dice)
101+
fore_dice = fore_dice.squeeze(1)
101102

102103
# Average class scores
103104
loss = torch.mean(torch.stack([back_dice, fore_dice], dim=-1))
@@ -149,6 +150,11 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
149150
y_log_pred = F.logsigmoid(y_pred)
150151
y_pred = torch.sigmoid(y_pred)
151152

153+
if y_pred.shape[1] == 1:
154+
y_pred = torch.cat([1 - y_pred, y_pred], dim=1)
155+
y_log_pred = torch.log(torch.clamp(y_pred, 1e-7, 1.0))
156+
y_true = torch.cat([1 - y_true, y_true], dim=1)
157+
152158
n_pred_ch = y_pred.shape[1]
153159

154160
if self.to_onehot_y:
@@ -163,19 +169,18 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
163169
y_pred = torch.clamp(y_pred, self.epsilon, 1.0 - self.epsilon)
164170
cross_entropy = -y_true * y_log_pred
165171

172+
# Class 0: Background
166173
back_ce = torch.pow(1 - y_pred[:, 0], self.gamma) * cross_entropy[:, 0]
167174
back_ce = (1 - self.delta) * back_ce
168175

169-
if n_pred_ch > 1:
170-
fore_ce = cross_entropy[:, 1:]
171-
fore_ce = self.delta * fore_ce
176+
# Class 1+: Foreground
177+
fore_ce = cross_entropy[:, 1:]
178+
fore_ce = self.delta * fore_ce
172179

173-
if fore_ce.shape[1] > 1:
174-
fore_ce = torch.sum(fore_ce, dim=1)
175-
else:
176-
fore_ce = fore_ce.squeeze(1)
180+
if fore_ce.shape[1] > 1:
181+
fore_ce = torch.sum(fore_ce, dim=1)
177182
else:
178-
fore_ce = torch.zeros_like(back_ce)
183+
fore_ce = fore_ce.squeeze(1)
179184

180185
loss = torch.mean(torch.stack([back_ce, fore_ce], dim=-1))
181186
return loss

0 commit comments

Comments
 (0)