Commit 836ce42

Enhanced AsymmetricUnifiedFocalLoss with Sigmoid/Softmax
Signed-off-by: ytl0623 <[email protected]>
1 parent 1a28917

monai/losses/unified_focal_loss.py
1 file changed: 10 additions, 6 deletions
@@ -79,14 +79,16 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
         dice_class = (tp + self.epsilon) / (tp + self.delta * fn + (1 - self.delta) * fp + self.epsilon)

         # Calculate losses separately for each class, enhancing both classes
-        back_dice = 1 - dice_class[:, 0]
-        fore_dice = (1 - dice_class[:, 1]) * torch.pow(1 - dice_class[:, 1], -self.gamma)
+        back_dice = 1 - dice_class[:, 0:1]
+        fore_dice = (1 - dice_class[:, 1:]) * torch.pow(1 - dice_class[:, 1:], -self.gamma)

         if not self.include_background:
             back_dice = back_dice * 0.0

+        all_dice = torch.cat([back_dice, fore_dice], dim=1)
+
         # Average class scores
-        loss = torch.mean(torch.stack([back_dice, fore_dice], dim=-1))
+        loss = torch.mean(all_dice)
         return loss
@@ -141,16 +143,18 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
         y_pred = torch.clamp(y_pred, self.epsilon, 1.0 - self.epsilon)
         cross_entropy = -y_true * torch.log(y_pred)

-        back_ce = torch.pow(1 - y_pred[:, 0], self.gamma) * cross_entropy[:, 0]
+        back_ce = torch.pow(1 - y_pred[:, 0:1], self.gamma) * cross_entropy[:, 0:1]
         back_ce = (1 - self.delta) * back_ce

-        fore_ce = cross_entropy[:, 1]
+        fore_ce = cross_entropy[:, 1:]
         fore_ce = self.delta * fore_ce

         if not self.include_background:
             back_ce = back_ce * 0.0

-        loss = torch.mean(torch.sum(torch.stack([back_ce, fore_ce], dim=1), dim=1))
+        all_ce = torch.cat([back_ce, fore_ce], dim=1)
+
+        loss = torch.mean(torch.sum(all_ce, dim=1))
         return loss
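The same slicing pattern in the focal cross-entropy branch keeps the channel axis through `back_ce` and `fore_ce`, so the class-wise sum works for channel-first inputs with more than one foreground class. Below is a self-contained sketch of the updated computation; the tensor sizes and the `epsilon`/`gamma`/`delta` values are assumptions for illustration, not values taken from the commit:

```python
import torch
import torch.nn.functional as F

# Assumed sizes and hyperparameters, purely for illustration.
B, C, H, W = 2, 3, 8, 8
epsilon, gamma, delta = 1e-7, 2.0, 0.7

# Softmax predictions and one-hot targets in channel-first layout.
y_pred = torch.softmax(torch.randn(B, C, H, W), dim=1)
y_true = F.one_hot(torch.randint(0, C, (B, H, W)), C).permute(0, 3, 1, 2).float()

y_pred = torch.clamp(y_pred, epsilon, 1.0 - epsilon)
cross_entropy = -y_true * torch.log(y_pred)

# Background keeps its focal down-weighting; slicing preserves the channel axis.
back_ce = (1 - delta) * torch.pow(1 - y_pred[:, 0:1], gamma) * cross_entropy[:, 0:1]  # (B, 1, H, W)
# All foreground channels are taken at once, so C > 2 works.
fore_ce = delta * cross_entropy[:, 1:]                                                # (B, C-1, H, W)

all_ce = torch.cat([back_ce, fore_ce], dim=1)   # (B, C, H, W)
loss = torch.mean(torch.sum(all_ce, dim=1))     # sum over classes, mean over batch/space
```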