Skip to content

Commit 8731b30

Browse files
committed
Simplify algebraic expression to avoid numerical instability
Signed-off-by: ytl0623 <[email protected]>
1 parent d6e4335 commit 8731b30

File tree

1 file changed

+12
-9
lines changed

1 file changed

+12
-9
lines changed

monai/losses/unified_focal_loss.py

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ def __init__(
4949
If True, softmax is used. If False, sigmoid is used. Defaults to False.
5050
delta : weight of the background. Defaults to 0.7.
5151
gamma : value of the exponent gamma in the definition of the Focal loss . Defaults to 0.75.
52-
epsilon : it defines a very small number each time. simmily smooth value. Defaults to 1e-7.
52+
epsilon : it defines a very small number each time. similarly smooth value. Defaults to 1e-7.
5353
"""
5454
super().__init__(reduction=LossReduction(reduction).value)
5555
self.to_onehot_y = to_onehot_y
@@ -88,12 +88,16 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
8888

8989
# Calculate losses separately for each class, enhancing both classes
9090
back_dice = 1 - dice_class[:, 0]
91-
fore_dice = (1 - dice_class[:, 1:]) * torch.pow(1 - dice_class[:, 1:], -self.gamma)
9291

93-
if fore_dice.shape[1] > 1:
94-
fore_dice = torch.mean(fore_dice, dim=1)
92+
if n_pred_ch > 1:
93+
fore_dice = torch.pow(1 - dice_class[:, 1:], 1 - self.gamma)
94+
95+
if fore_dice.shape[1] > 1:
96+
fore_dice = torch.mean(fore_dice, dim=1)
97+
else:
98+
fore_dice = fore_dice.squeeze(1)
9599
else:
96-
fore_dice = fore_dice.squeeze(1)
100+
fore_dice = torch.zeros_like(back_dice)
97101

98102
# Average class scores
99103
loss = torch.mean(torch.stack([back_dice, fore_dice], dim=-1))
@@ -128,7 +132,7 @@ def __init__(
128132
If True, softmax is used. If False, sigmoid is used. Defaults to False.
129133
delta : weight of the background. Defaults to 0.7.
130134
gamma : value of the exponent gamma in the definition of the Focal loss . Defaults to 0.75.
131-
epsilon : it defines a very small number each time. simmily smooth value. Defaults to 1e-7.
135+
epsilon : it defines a very small number each time. similarly smooth value. Defaults to 1e-7.
132136
"""
133137
super().__init__(reduction=LossReduction(reduction).value)
134138
self.to_onehot_y = to_onehot_y
@@ -198,8 +202,8 @@ def __init__(
198202
"""
199203
Args:
200204
to_onehot_y : whether to convert `y` into the one-hot format. Defaults to False.
201-
weight : weight for each loss function, if it's none it's 0.5. Defaults to None.
202-
gamma : value of the exponent gamma in the definition of the Focal loss. Defaults to 0.75.
205+
weight : weight for each loss function. Defaults to 0.5.
206+
gamma : value of the exponent gamma in the definition of the Focal loss. Defaults to 0.5.
203207
delta : weight of the background. Defaults to 0.7.
204208
use_softmax: whether to use softmax to transform the original logits into probabilities.
205209
If True, softmax is used. If False, sigmoid is used. Defaults to False.
@@ -235,7 +239,6 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
235239
236240
Raises:
237241
ValueError: When input and target are different shape
238-
ValueError: When the number of classes entered does not match the expected number
239242
"""
240243
if y_pred.shape != y_true.shape:
241244
raise ValueError(f"ground truth has different shape ({y_true.shape}) from input ({y_pred.shape})")

0 commit comments

Comments
 (0)