Skip to content

Commit 0a29b5e

Browse files
committed
deleta use_sigmoid parameter
Signed-off-by: ytl0623 <[email protected]>
1 parent 0883c9c commit 0a29b5e

File tree

1 file changed

+10
-7
lines changed

1 file changed

+10
-7
lines changed

monai/losses/unified_focal_loss.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -163,17 +163,19 @@ def __init__(
163163
delta: float = 0.7,
164164
reduction: LossReduction | str = LossReduction.MEAN,
165165
include_background: bool = True,
166-
sigmoid: bool = False,
167-
softmax: bool = False,
166+
use_softmax: bool = False
168167
):
169168
"""
170169
Args:
171170
to_onehot_y : whether to convert `y` into the one-hot format. Defaults to False.
172171
num_classes : number of classes, it only supports 2 now. Defaults to 2.
172+
weight : weight for combining focal loss and focal tversky loss. Defaults to 0.5.
173+
gamma : value of the exponent gamma in the definition of the Focal loss. Defaults to 0.5.
173174
delta : weight of the background. Defaults to 0.7.
174-
gamma : value of the exponent gamma in the definition of the Focal loss. Defaults to 0.75.
175-
epsilon : it defines a very small number each time. simmily smooth value. Defaults to 1e-7.
176-
weight : weight for each loss function, if it's none it's 0.5. Defaults to None.
175+
reduction : reduction mode for the loss. Defaults to MEAN.
176+
include_background : whether to include the background class in loss calculation. Defaults to True.
177+
use_softmax: whether to use softmax to transform the original logits into probabilities.
178+
If True, softmax is used. If False, sigmoid is used. Defaults to False.
177179
178180
Example:
179181
>>> import torch
@@ -184,6 +186,8 @@ def __init__(
184186
>>> fl(pred, grnd)
185187
"""
186188
super().__init__(reduction=LossReduction(reduction).value)
189+
if use_sigmoid and use_softmax:
190+
raise ValueError("use_sigmoid and use_softmax are mutually exclusive; only one can be True.")
187191
self.to_onehot_y = to_onehot_y
188192
self.num_classes = num_classes
189193
self.gamma = gamma
@@ -192,8 +196,7 @@ def __init__(
192196
self.asy_focal_loss = AsymmetricFocalLoss(gamma=self.gamma, delta=self.delta)
193197
self.asy_focal_tversky_loss = AsymmetricFocalTverskyLoss(gamma=self.gamma, delta=self.delta)
194198
self.include_background = include_background
195-
self.sigmoid = sigmoid
196-
self.softmax = softmax
199+
self.use_softmax = use_softmax
197200

198201
# TODO: Implement this function to support multiple classes segmentation
199202
def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:

0 commit comments

Comments
 (0)