Commit ea8a6ee

ytl0623 committed: minor fixes
Signed-off-by: ytl0623 <[email protected]>
1 parent: 8731b30

1 file changed: monai/losses/unified_focal_loss.py (9 additions, 6 deletions)
@@ -131,7 +131,7 @@ def __init__(
             use_softmax: whether to use softmax to transform the original logits into probabilities.
                 If True, softmax is used. If False, sigmoid is used. Defaults to False.
             delta : weight of the background. Defaults to 0.7.
-            gamma : value of the exponent gamma in the definition of the Focal loss . Defaults to 0.75.
+            gamma : value of the exponent gamma in the definition of the Focal loss. Defaults to 2.
             epsilon : it defines a very small number each time. similarly smooth value. Defaults to 1e-7.
         """
         super().__init__(reduction=LossReduction(reduction).value)
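
Context for the gamma default: in AsymmetricFocalLoss the exponent only modulates the background channel, as the forward hunk below shows. Transcribing that line into plain notation (my paraphrase, not a formula stated in the commit):

    back_ce = (1 - delta) * (1 - p0)^gamma * CE0

where p0 = y_pred[:, 0] and CE0 = cross_entropy[:, 0]. A larger gamma suppresses the contribution of confidently classified background pixels; 2 is the default proposed in the original focal loss paper (Lin et al., 2017), which the docstring now reflects.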
@@ -166,13 +166,16 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
         back_ce = torch.pow(1 - y_pred[:, 0], self.gamma) * cross_entropy[:, 0]
         back_ce = (1 - self.delta) * back_ce

-        fore_ce = cross_entropy[:, 1:]
-        fore_ce = self.delta * fore_ce
+        if n_pred_ch > 1:
+            fore_ce = cross_entropy[:, 1:]
+            fore_ce = self.delta * fore_ce

-        if fore_ce.shape[1] > 1:
-            fore_ce = torch.sum(fore_ce, dim=1)
+            if fore_ce.shape[1] > 1:
+                fore_ce = torch.sum(fore_ce, dim=1)
+            else:
+                fore_ce = fore_ce.squeeze(1)
         else:
-            fore_ce = fore_ce.squeeze(1)
+            fore_ce = torch.zeros_like(back_ce)

         loss = torch.mean(torch.stack([back_ce, fore_ce], dim=-1))
         return loss
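
Why the n_pred_ch guard matters: with a single-channel (background-only) prediction, the old code sliced cross_entropy[:, 1:], yielding an empty channel dimension that cannot be stacked against back_ce. A minimal, self-contained sketch of the patched branch, runnable outside MONAI (the shapes and the delta/gamma values here are illustrative assumptions, not taken from the commit):

    import torch

    # Illustrative defaults; delta weights foreground vs. background.
    delta, gamma = 0.7, 2.0

    # Hypothetical single-channel prediction: batch 2, 1 channel, 4x4 map.
    y_pred = torch.sigmoid(torch.randn(2, 1, 4, 4))
    cross_entropy = -torch.log(y_pred.clamp_min(1e-7))
    n_pred_ch = y_pred.shape[1]

    # Background term: focal modulation (1 - p)^gamma on channel 0.
    back_ce = (1 - delta) * torch.pow(1 - y_pred[:, 0], gamma) * cross_entropy[:, 0]

    if n_pred_ch > 1:
        fore_ce = delta * cross_entropy[:, 1:]
        fore_ce = torch.sum(fore_ce, dim=1) if fore_ce.shape[1] > 1 else fore_ce.squeeze(1)
    else:
        # The fix: no foreground channels means a zero foreground term,
        # instead of slicing an empty tensor.
        fore_ce = torch.zeros_like(back_ce)

    loss = torch.mean(torch.stack([back_ce, fore_ce], dim=-1))
    print(loss)  # scalar tensor; runs without shape errors for 1-channel input

With two or more channels the sketch reduces to the pre-existing foreground path, so the change is behavior-preserving for multi-class inputs.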
