@@ -163,17 +163,19 @@ def __init__(
         delta: float = 0.7,
         reduction: LossReduction | str = LossReduction.MEAN,
         include_background: bool = True,
-        sigmoid: bool = False,
-        softmax: bool = False,
+        use_softmax: bool = False,
     ):
169168 """
170169 Args:
171170 to_onehot_y : whether to convert `y` into the one-hot format. Defaults to False.
172171 num_classes : number of classes, it only supports 2 now. Defaults to 2.
+            weight: weight for combining the focal loss and the focal Tversky loss. Defaults to 0.5.
+            gamma: value of the exponent gamma in the definition of the focal loss. Defaults to 0.5.
             delta: weight of the background. Defaults to 0.7.
-            gamma: value of the exponent gamma in the definition of the Focal loss. Defaults to 0.75.
-            epsilon: it defines a very small number each time. simmily smooth value. Defaults to 1e-7.
-            weight: weight for each loss function, if it's none it's 0.5. Defaults to None.
+            reduction: reduction mode of the loss. Defaults to MEAN.
+            include_background: whether to include the background class in the loss calculation. Defaults to True.
+            use_softmax: whether to use softmax to transform the original logits into probabilities.
+                If True, softmax is used; if False, sigmoid is used. Defaults to False.
 
         Example:
             >>> import torch
@@ -184,6 +186,6 @@ def __init__(
             >>> fl(pred, grnd)
         """
         super().__init__(reduction=LossReduction(reduction).value)
         self.to_onehot_y = to_onehot_y
         self.num_classes = num_classes
         self.gamma = gamma
@@ -192,8 +194,7 @@ def __init__(
         self.asy_focal_loss = AsymmetricFocalLoss(gamma=self.gamma, delta=self.delta)
         self.asy_focal_tversky_loss = AsymmetricFocalTverskyLoss(gamma=self.gamma, delta=self.delta)
         self.include_background = include_background
-        self.sigmoid = sigmoid
-        self.softmax = softmax
+        self.use_softmax = use_softmax
 
     # TODO: Implement this function to support multiple classes segmentation
     def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
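
For context, a minimal sketch of what the new use_softmax flag selects, mirroring the docstring above. The class name AsymmetricUnifiedFocalLoss and the weight/(1 - weight) combination of the two sub-losses are assumptions (neither the class header nor the forward body appears in these hunks), so the constructor call is left as a comment:

    import torch

    logits = torch.randn(2, 2, 16, 16)  # B x C x H x W raw network outputs

    use_softmax = False
    if use_softmax:
        # softmax: per-pixel probabilities that sum to 1 across channels
        probs = torch.softmax(logits, dim=1)
    else:
        # sigmoid: independent per-channel probabilities (the default)
        probs = torch.sigmoid(logits)

    # Hypothetical usage once this patch is applied (class name assumed; the
    # docstring suggests total = weight * focal + (1 - weight) * focal_tversky):
    # loss = AsymmetricUnifiedFocalLoss(to_onehot_y=True, use_softmax=True)
    # loss(pred, grnd)

With a single boolean flag, sigmoid and softmax are mutually exclusive by construction, which is why no extra validation of the flag is needed in __init__.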