Skip to content

Commit 32540ce

Browse files
committed
Refactor parameters for UnifiedFocalLoss class
Signed-off-by: ytl0623 <[email protected]
1 parent 0d913eb commit 32540ce

File tree

1 file changed

+17
-17
lines changed

1 file changed

+17
-17
lines changed

monai/losses/unified_focal_loss.py

Lines changed: 17 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -192,48 +192,48 @@ class AsymmetricUnifiedFocalLoss(_Loss):
192192

193193
def __init__(
194194
self,
195-
include_background: bool = True,
196195
to_onehot_y: bool = False,
197-
sigmoid: bool = False,
198-
softmax: bool = False,
196+
use_sigmoid: bool = False,
197+
use_softmax: bool = False,
199198
lambda_focal: float = 0.5,
200199
focal_loss_gamma: float = 2.0,
201200
focal_loss_delta: float = 0.7,
202201
tversky_loss_gamma: float = 0.75,
203202
tversky_loss_delta: float = 0.7,
203+
include_background: bool = True,
204204
reduction: LossReduction | str = LossReduction.MEAN,
205205
):
206206
"""
207207
Args:
208-
include_background: whether to include loss computation for the background class. Defaults to True.
209208
to_onehot_y : whether to convert `y` into the one-hot format. Defaults to False.
210-
sigmoid: if True, apply a sigmoid activation to the input y_pred.
211-
softmax: if True, apply a softmax activation to the input y_pred.
209+
use_sigmoid: if True, apply a sigmoid activation to the input y_pred.
210+
use_softmax: if True, apply a softmax activation to the input y_pred.
212211
lambda_focal: the weight for AsymmetricFocalLoss (Cross-Entropy based).
213212
The weight for AsymmetricFocalTverskyLoss will be (1 - lambda_focal). Defaults to 0.5.
214213
focal_loss_gamma: gamma parameter for the AsymmetricFocalLoss component. Defaults to 2.0.
215214
focal_loss_delta: delta parameter for the AsymmetricFocalLoss component. Defaults to 0.7.
216215
tversky_loss_gamma: gamma parameter for the AsymmetricFocalTverskyLoss component. Defaults to 0.75.
217216
tversky_loss_delta: delta parameter for the AsymmetricFocalTverskyLoss component. Defaults to 0.7.
217+
include_background: whether to include loss computation for the background class. Defaults to True.
218218
reduction: specifies the reduction to apply to the output: "none", "mean", "sum".
219219
220220
Example:
221221
>>> import torch
222222
>>> from monai.losses import AsymmetricUnifiedFocalLoss
223223
>>> pred = torch.randn((1, 2, 32, 32), dtype=torch.float32)
224224
>>> grnd = torch.randint(0, 2, (1, 1, 32, 32), dtype=torch.int64)
225-
>>> fl = AsymmetricUnifiedFocalLoss(softmax=True, to_onehot_y=True)
225+
>>> fl = AsymmetricUnifiedFocalLoss(use_softmax=True, to_onehot_y=True)
226226
>>> fl(pred, grnd)
227227
"""
228228
super().__init__(reduction=LossReduction(reduction).value)
229-
self.include_background = include_background
230229
self.to_onehot_y = to_onehot_y
231-
self.sigmoid = sigmoid
232-
self.softmax = softmax
230+
self.use_sigmoid = use_sigmoid
231+
self.use_softmax = use_softmax
233232
self.lambda_focal = lambda_focal
233+
self.include_background = include_background
234234

235-
if sigmoid and softmax:
236-
raise ValueError("Both sigmoid and softmax cannot be True.")
235+
if self.use_sigmoid and self.use_softmax:
236+
raise ValueError("Both use_sigmoid and use_softmax cannot be True.")
237237

238238
self.asy_focal_loss = AsymmetricFocalLoss(
239239
include_background=self.include_background,
@@ -257,18 +257,18 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
257257
n_pred_ch = y_pred.shape[1]
258258

259259
y_pred_act = y_pred
260-
if self.sigmoid:
260+
if self.use_sigmoid:
261261
y_pred_act = torch.sigmoid(y_pred)
262-
elif self.softmax:
262+
elif self.use_softmax:
263263
if n_pred_ch == 1:
264-
warnings.warn("single channel prediction, softmax=True ignored.")
264+
warnings.warn("single channel prediction, use_softmax=True ignored.")
265265
else:
266266
y_pred_act = torch.softmax(y_pred, dim=1)
267267

268268
if self.to_onehot_y:
269-
if n_pred_ch == 1 and not self.sigmoid:
269+
if n_pred_ch == 1 and not self.use_sigmoid:
270270
warnings.warn("single channel prediction, `to_onehot_y=True` ignored.")
271-
elif n_pred_ch > 1 or self.sigmoid:
271+
elif n_pred_ch > 1 or self.use_sigmoid:
272272
# Ensure y_true is (B, 1, H, W, [D]) for one-hot conversion
273273
if y_true.shape[1] != 1:
274274
y_true = y_true.unsqueeze(1)

0 commit comments

Comments
 (0)