
Commit e0e48a3

ytl0623 committed: minor fixes
Signed-off-by: ytl0623 <[email protected]>
1 parent 6c25189 commit e0e48a3

File tree: 1 file changed (+41 −31 lines)

monai/losses/unified_focal_loss.py

Lines changed: 41 additions & 31 deletions
@@ -48,8 +48,8 @@ def __init__(
             use_softmax: whether to use softmax to transform the original logits into probabilities.
                 If True, softmax is used. If False, sigmoid is used. Defaults to False.
             delta : weight of the background. Defaults to 0.7.
-            gamma : value of the exponent gamma in the definition of the Focal loss . Defaults to 0.75.
-            epsilon : it defines a very small number each time. similarly smooth value. Defaults to 1e-7.
+            gamma : value of the exponent gamma in the definition of the Focal loss. Defaults to 0.75.
+            epsilon : stability factor used to avoid division by zero. Defaults to 1e-7.
         """
         super().__init__(reduction=LossReduction(reduction).value)
         self.to_onehot_y = to_onehot_y
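
Note: the reworded docstring matches what `epsilon` actually does. A minimal sketch (not part of this diff, assuming the standard Tversky-style denominator this file computes from tp/fn/fp) of where the term enters:

    import torch

    def tversky_index(tp: torch.Tensor, fn: torch.Tensor, fp: torch.Tensor,
                      delta: float = 0.7, epsilon: float = 1e-7) -> torch.Tensor:
        # Without epsilon, an empty class (tp == fn == fp == 0) yields 0/0 -> NaN.
        return (tp + epsilon) / (tp + delta * fn + (1 - delta) * fp + epsilon)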
@@ -59,6 +59,17 @@ def __init__(
         self.epsilon = epsilon

     def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
+        n_pred_ch = y_pred.shape[1]
+
+        if self.use_softmax and n_pred_ch == 1:
+            raise ValueError("single channel prediction with `use_softmax=True` is not allowed.")
+
+        if self.to_onehot_y:
+            if n_pred_ch == 1:
+                warnings.warn("single channel prediction, `to_onehot_y=True` ignored.")
+            else:
+                y_true = one_hot(y_true, num_classes=n_pred_ch)
+
         if self.use_softmax:
             y_pred = torch.softmax(y_pred, dim=1)
         else:
@@ -68,17 +79,10 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
             y_pred = torch.cat([1 - y_pred, y_pred], dim=1)
             y_true = torch.cat([1 - y_true, y_true], dim=1)

-        n_pred_ch = y_pred.shape[1]
-
-        if self.to_onehot_y:
-            if n_pred_ch == 1:
-                warnings.warn("single channel prediction, `to_onehot_y=True` ignored.")
-            else:
-                y_true = one_hot(y_true, num_classes=n_pred_ch)
-
         if y_true.shape != y_pred.shape:
             raise ValueError(f"ground truth has different shape ({y_true.shape}) from input ({y_pred.shape})")

+        # Calculate Loss
         axis = list(range(2, len(y_pred.shape)))

         # Calculate true positives (tp), false negatives (fn) and false positives (fp)
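
Note: this hunk pairs with the previous one. Reading `n_pred_ch` after the single-channel expansion (the old order) reports 2 channels, so `to_onehot_y` would one-hot a `y_true` that had already been expanded; computing it before any transform fixes that, and also lets the new `use_softmax` check reject single-channel input early. A usage sketch under the new order (assuming the class is importable from `monai.losses`):

    import torch
    from monai.losses import AsymmetricFocalTverskyLoss  # assumed import path

    pred = torch.randn(2, 1, 32, 32)                 # single-channel logits
    grnd = (torch.rand(2, 1, 32, 32) > 0.5).float()  # binary mask
    # sigmoid path; channels are expanded to [background, foreground] internally
    print(AsymmetricFocalTverskyLoss(use_softmax=False)(pred, grnd))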
@@ -130,8 +134,8 @@ def __init__(
             use_softmax: whether to use softmax to transform the original logits into probabilities.
                 If True, softmax is used. If False, sigmoid is used. Defaults to False.
             delta : weight of the background. Defaults to 0.7.
-            gamma : value of the exponent gamma in the definition of the Focal loss . Defaults to 2.
-            epsilon : it defines a very small number each time. similarly smooth value. Defaults to 1e-7.
+            gamma : value of the exponent gamma in the definition of the Focal loss. Defaults to 2.
+            epsilon : stability factor used to avoid division by zero. Defaults to 1e-7.
         """
         super().__init__(reduction=LossReduction(reduction).value)
         self.to_onehot_y = to_onehot_y
@@ -141,26 +145,34 @@ def __init__(
         self.epsilon = epsilon

     def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
+        n_pred_ch = y_pred.shape[1]
+
+        if self.use_softmax and n_pred_ch == 1:
+            raise ValueError("single channel prediction with `use_softmax=True` is not allowed.")
+
+        if self.to_onehot_y:
+            if n_pred_ch == 1:
+                warnings.warn("single channel prediction, `to_onehot_y=True` ignored.")
+            else:
+                y_true = one_hot(y_true, num_classes=n_pred_ch)
+
+        # Save logits for numerical stability in single-channel expansion
+        y_logits = y_pred
+
         if self.use_softmax:
             y_log_pred = F.log_softmax(y_pred, dim=1)
             y_pred = torch.exp(y_log_pred)
         else:
             y_log_pred = F.logsigmoid(y_pred)
             y_pred = torch.sigmoid(y_pred)

-        if y_pred.shape[1] == 1:
+        # Handle Single Channel (Binary) Expansion
+        if n_pred_ch == 1:
             y_pred = torch.cat([1 - y_pred, y_pred], dim=1)
-            y_log_pred = torch.log(torch.clamp(y_pred, 1e-7, 1.0))
+            bg_log_pred = F.logsigmoid(-y_logits)
+            y_log_pred = torch.cat([bg_log_pred, y_log_pred], dim=1)
             y_true = torch.cat([1 - y_true, y_true], dim=1)

-        n_pred_ch = y_pred.shape[1]
-
-        if self.to_onehot_y:
-            if n_pred_ch == 1:
-                warnings.warn("single channel prediction, `to_onehot_y=True` ignored.")
-            else:
-                y_true = one_hot(y_true, num_classes=n_pred_ch)
-
         if y_true.shape != y_pred.shape:
             raise ValueError(f"ground truth has different shape ({y_true.shape}) from input ({y_pred.shape})")
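
Note: the replaced line computed log(1 - sigmoid(x)) by clamping probabilities, which saturates once sigmoid(x) rounds to 1.0 in float32 (|x| around 17 is enough). Since 1 - sigmoid(x) = sigmoid(-x), `F.logsigmoid(-y_logits)` yields the same quantity without the clamp. A sketch verifying the difference (hypothetical values, not from the diff):

    import torch
    import torch.nn.functional as F

    x = torch.tensor([20.0])  # confidently-foreground logit
    naive = torch.log(torch.clamp(1 - torch.sigmoid(x), 1e-7, 1.0))
    stable = F.logsigmoid(-x)
    print(naive.item())   # -16.118...  (pinned at log(1e-7) by the clamp)
    print(stable.item())  # -20.0       (correct to float precision)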

@@ -199,20 +211,18 @@ class AsymmetricUnifiedFocalLoss(_Loss):
     def __init__(
         self,
         to_onehot_y: bool = False,
-        weight: float = 0.5,
-        gamma: float = 0.5,
-        delta: float = 0.7,
         use_softmax: bool = False,
+        delta: float = 0.7,
+        gamma: float = 2,
         reduction: LossReduction | str = LossReduction.MEAN,
     ):
         """
         Args:
-            to_onehot_y : whether to convert `y` into the one-hot format. Defaults to False.
-            weight : weight for each loss function. Defaults to 0.5.
-            gamma : value of the exponent gamma in the definition of the Focal loss. Defaults to 0.5.
-            delta : weight of the background. Defaults to 0.7.
+            to_onehot_y: whether to convert `y` into the one-hot format. Defaults to False.
             use_softmax: whether to use softmax to transform the original logits into probabilities.
                 If True, softmax is used. If False, sigmoid is used. Defaults to False.
+            delta : weight of the background. Defaults to 0.7.
+            gamma : value of the exponent gamma in the definition of the Focal loss. Defaults to 2.

         Example:
             >>> import torch
@@ -250,9 +260,9 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
         loss: torch.Tensor = self.weight * asy_focal_loss + (1 - self.weight) * asy_focal_tversky_loss

         if self.reduction == LossReduction.SUM.value:
-            return torch.sum(loss)  # sum over the batch and channel dims
+            return torch.sum(loss)
         if self.reduction == LossReduction.NONE.value:
-            return loss  # returns [N, num_classes] losses
+            return loss
         if self.reduction == LossReduction.MEAN.value:
             return torch.mean(loss)
         raise ValueError(f'Unsupported reduction: {self.reduction}, available options are ["mean", "sum", "none"].')
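
Note: with `weight` dropped from the signature and `delta`/`gamma` reordered after `use_softmax`, positional callers of `AsymmetricUnifiedFocalLoss` would silently change meaning; keyword arguments are safe. A usage sketch against the new signature (assumed import path):

    import torch
    from monai.losses import AsymmetricUnifiedFocalLoss  # assumed import path

    loss_fn = AsymmetricUnifiedFocalLoss(
        to_onehot_y=True, use_softmax=True, delta=0.7, gamma=2, reduction="mean"
    )
    pred = torch.randn(4, 2, 32, 32)                    # two-class logits
    grnd = torch.randint(0, 2, (4, 1, 32, 32)).float()  # class-index labels
    print(loss_fn(pred, grnd))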
