@@ -192,48 +192,48 @@ class AsymmetricUnifiedFocalLoss(_Loss):
 
     def __init__(
         self,
-        include_background: bool = True,
         to_onehot_y: bool = False,
-        sigmoid: bool = False,
-        softmax: bool = False,
+        use_sigmoid: bool = False,
+        use_softmax: bool = False,
         lambda_focal: float = 0.5,
         focal_loss_gamma: float = 2.0,
         focal_loss_delta: float = 0.7,
         tversky_loss_gamma: float = 0.75,
         tversky_loss_delta: float = 0.7,
+        include_background: bool = True,
         reduction: LossReduction | str = LossReduction.MEAN,
     ):
         """
         Args:
-            include_background: whether to include loss computation for the background class. Defaults to True.
             to_onehot_y: whether to convert `y` into the one-hot format. Defaults to False.
-            sigmoid: if True, apply a sigmoid activation to the input y_pred.
-            softmax: if True, apply a softmax activation to the input y_pred.
+            use_sigmoid: if True, apply a sigmoid activation to the input y_pred.
+            use_softmax: if True, apply a softmax activation to the input y_pred.
             lambda_focal: the weight for AsymmetricFocalLoss (Cross-Entropy based).
                 The weight for AsymmetricFocalTverskyLoss will be (1 - lambda_focal). Defaults to 0.5.
             focal_loss_gamma: gamma parameter for the AsymmetricFocalLoss component. Defaults to 2.0.
             focal_loss_delta: delta parameter for the AsymmetricFocalLoss component. Defaults to 0.7.
             tversky_loss_gamma: gamma parameter for the AsymmetricFocalTverskyLoss component. Defaults to 0.75.
             tversky_loss_delta: delta parameter for the AsymmetricFocalTverskyLoss component. Defaults to 0.7.
+            include_background: whether to include loss computation for the background class. Defaults to True.
             reduction: specifies the reduction to apply to the output: "none", "mean", "sum".
 
         Example:
             >>> import torch
             >>> from monai.losses import AsymmetricUnifiedFocalLoss
             >>> pred = torch.randn((1, 2, 32, 32), dtype=torch.float32)
             >>> grnd = torch.randint(0, 2, (1, 1, 32, 32), dtype=torch.int64)
-            >>> fl = AsymmetricUnifiedFocalLoss(softmax=True, to_onehot_y=True)
+            >>> fl = AsymmetricUnifiedFocalLoss(use_softmax=True, to_onehot_y=True)
             >>> fl(pred, grnd)
         """
         super().__init__(reduction=LossReduction(reduction).value)
-        self.include_background = include_background
         self.to_onehot_y = to_onehot_y
-        self.sigmoid = sigmoid
-        self.softmax = softmax
+        self.use_sigmoid = use_sigmoid
+        self.use_softmax = use_softmax
         self.lambda_focal = lambda_focal
+        self.include_background = include_background
 
-        if sigmoid and softmax:
-            raise ValueError("Both sigmoid and softmax cannot be True.")
+        if self.use_sigmoid and self.use_softmax:
+            raise ValueError("Both use_sigmoid and use_softmax cannot be True.")
 
         self.asy_focal_loss = AsymmetricFocalLoss(
             include_background=self.include_background,
@@ -257,18 +257,18 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
         n_pred_ch = y_pred.shape[1]
 
         y_pred_act = y_pred
-        if self.sigmoid:
+        if self.use_sigmoid:
             y_pred_act = torch.sigmoid(y_pred)
-        elif self.softmax:
+        elif self.use_softmax:
             if n_pred_ch == 1:
-                warnings.warn("single channel prediction, softmax=True ignored.")
+                warnings.warn("single channel prediction, use_softmax=True ignored.")
             else:
                 y_pred_act = torch.softmax(y_pred, dim=1)
 
         if self.to_onehot_y:
-            if n_pred_ch == 1 and not self.sigmoid:
+            if n_pred_ch == 1 and not self.use_sigmoid:
                 warnings.warn("single channel prediction, `to_onehot_y=True` ignored.")
-            elif n_pred_ch > 1 or self.sigmoid:
+            elif n_pred_ch > 1 or self.use_sigmoid:
                 # Ensure y_true is (B, 1, H, W, [D]) for one-hot conversion
                 if y_true.shape[1] != 1:
                     y_true = y_true.unsqueeze(1)
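
A minimal usage sketch of the renamed API, based on the docstring example and the validation check visible in this diff; it assumes the change lands as shown here:

```python
import torch
from monai.losses import AsymmetricUnifiedFocalLoss

# Two-channel logits and integer class labels, matching the docstring example.
pred = torch.randn((1, 2, 32, 32), dtype=torch.float32)
grnd = torch.randint(0, 2, (1, 1, 32, 32), dtype=torch.int64)

# The activation flags are renamed: use_softmax/use_sigmoid replace softmax/sigmoid.
loss_fn = AsymmetricUnifiedFocalLoss(use_softmax=True, to_onehot_y=True)
loss = loss_fn(pred, grnd)

# Per the check added in __init__, enabling both activations now raises
# ValueError("Both use_sigmoid and use_softmax cannot be True."):
# AsymmetricUnifiedFocalLoss(use_sigmoid=True, use_softmax=True)
```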