@@ -27,8 +27,6 @@ class AsymmetricFocalTverskyLoss(_Loss):
 
     It supports both binary and multi-class segmentation.
 
-    The logic assumes channel 0 is Background, and channels 1..N are Foreground.
-
     Reimplementation of the Asymmetric Focal Tversky Loss described in:
 
     - "Unified Focal Loss: Generalising Dice and Cross Entropy-based Losses to Handle Class Imbalanced Medical Image Segmentation",
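
The removed docstring line documented a channel convention for one-hot inputs. As a quick illustration of what it meant (a hypothetical sketch, assuming a one-hot layout of shape (B, C, *spatial); the tensors below are ours, not from the source):

import torch

# Illustration of the convention the removed line described:
# channel 0 is the background mask, channels 1..N the foreground classes.
labels = torch.tensor([[0, 1], [2, 0]])  # integer class map with 3 classes
onehot = torch.nn.functional.one_hot(labels, 3).permute(2, 0, 1)
print(onehot[0])   # background mask: 1 exactly where labels == 0
print(onehot[1:])  # foreground masks for classes 1 and 2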
@@ -80,12 +78,12 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
         # clip the prediction to avoid NaN
         y_pred = torch.clamp(y_pred, self.epsilon, 1.0 - self.epsilon)
 
-        spatial_dims = list(range(2, len(y_pred.shape)))
+        axis = list(range(2, len(y_pred.shape)))
 
         # Calculate true positives (tp), false negatives (fn) and false positives (fp)
-        tp = torch.sum(y_true * y_pred, dim=spatial_dims)
-        fn = torch.sum(y_true * (1 - y_pred), dim=spatial_dims)
-        fp = torch.sum((1 - y_true) * y_pred, dim=spatial_dims)
+        tp = torch.sum(y_true * y_pred, dim=axis)
+        fn = torch.sum(y_true * (1 - y_pred), dim=axis)
+        fp = torch.sum((1 - y_true) * y_pred, dim=axis)
         dice_class = (tp + self.epsilon) / (tp + self.delta * fn + (1 - self.delta) * fp + self.epsilon)
 
         # Calculate losses separately for each class, enhancing both classes
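
For readers following the arithmetic, here is a minimal standalone sketch of the per-class Tversky index computed in this hunk. The helper name tversky_index is ours; delta=0.7 and epsilon=1e-7 mirror the defaults documented in the docstring further down:

import torch

def tversky_index(y_pred: torch.Tensor, y_true: torch.Tensor,
                  delta: float = 0.7, epsilon: float = 1e-7) -> torch.Tensor:
    # y_pred holds probabilities, y_true is one-hot; both shaped (B, C, *spatial)
    y_pred = torch.clamp(y_pred, epsilon, 1.0 - epsilon)
    axis = list(range(2, len(y_pred.shape)))  # reduce over spatial dims only
    tp = torch.sum(y_true * y_pred, dim=axis)
    fn = torch.sum(y_true * (1 - y_pred), dim=axis)
    fp = torch.sum((1 - y_true) * y_pred, dim=axis)
    # delta > 0.5 penalises false negatives more than false positives,
    # which is the asymmetry the loss name refers to
    return (tp + epsilon) / (tp + delta * fn + (1 - delta) * fp + epsilon)

probs = torch.softmax(torch.randn(1, 3, 32, 32), dim=1)
labels = torch.randint(0, 3, (1, 32, 32))
onehot = torch.nn.functional.one_hot(labels, 3).permute(0, 3, 1, 2).float()
print(tversky_index(probs, onehot).shape)  # torch.Size([1, 3]): one index per class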
@@ -200,32 +198,31 @@ def __init__(
         """
         Args:
             to_onehot_y: whether to convert `y` into the one-hot format. Defaults to False.
-            delta: weight of the background. Defaults to 0.7.
-            gamma: value of the exponent gamma in the definition of the Focal loss. Defaults to 0.75.
-            epsilon: it defines a very small number each time. simmily smooth value. Defaults to 1e-7.
             weight: weight for each loss function, if it's none it's 0.5. Defaults to None.
+            gamma: value of the exponent gamma in the definition of the Focal loss. Defaults to 0.75.
+            delta: weight of the background. Defaults to 0.7.
             use_softmax: whether to use softmax to transform the original logits into probabilities.
                 If True, softmax is used. If False, sigmoid is used. Defaults to False.
 
         Example:
             >>> import torch
             >>> from monai.losses import AsymmetricUnifiedFocalLoss
-            >>> pred = torch.ones((1, 1, 32, 32), dtype=torch.float32)
-            >>> grnd = torch.ones((1, 1, 32, 32), dtype=torch.int64)
-            >>> fl = AsymmetricUnifiedFocalLoss(to_onehot_y=True)
+            >>> pred = torch.randn((1, 3, 32, 32))
+            >>> grnd = torch.randint(0, 3, (1, 1, 32, 32))
+            >>> fl = AsymmetricUnifiedFocalLoss(use_softmax=True, to_onehot_y=True)
             >>> fl(pred, grnd)
         """
         super().__init__(reduction=LossReduction(reduction).value)
         self.to_onehot_y = to_onehot_y
+        self.weight: float = weight
         self.gamma = gamma
         self.delta = delta
-        self.weight: float = weight
         self.use_softmax = use_softmax
         self.asy_focal_loss = AsymmetricFocalLoss(
-            gamma=self.gamma, delta=self.delta, use_softmax=use_softmax, to_onehot_y=to_onehot_y
+            to_onehot_y=to_onehot_y, gamma=self.gamma, delta=self.delta, use_softmax=use_softmax
         )
         self.asy_focal_tversky_loss = AsymmetricFocalTverskyLoss(
-            gamma=self.gamma, delta=self.delta, use_softmax=use_softmax, to_onehot_y=to_onehot_y
+            to_onehot_y=to_onehot_y, gamma=self.gamma, delta=self.delta, use_softmax=use_softmax
         )
 
     def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
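
For reference, the updated docstring example runs end-to-end like this (a sketch assuming only the constructor arguments visible in this diff):

import torch
from monai.losses import AsymmetricUnifiedFocalLoss

# Multi-class logits are mapped to probabilities with softmax, and the
# integer label map is one-hot encoded internally via to_onehot_y=True.
pred = torch.randn((1, 3, 32, 32))          # raw logits, 3 classes
grnd = torch.randint(0, 3, (1, 1, 32, 32))  # integer class labels
fl = AsymmetricUnifiedFocalLoss(use_softmax=True, to_onehot_y=True)
loss = fl(pred, grnd)
print(loss)  # scalar loss tensor

Swapping the previous all-ones single-channel example for random multi-class tensors matches the constructor's softmax path, which expects more than one channel.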