Skip to content

Commit d6e4335

Browse files
committed
add files
Signed-off-by: ytl0623 <[email protected]>
1 parent d724a95 commit d6e4335

File tree

1 file changed

+12
-15
lines changed

1 file changed

+12
-15
lines changed

monai/losses/unified_focal_loss.py

Lines changed: 12 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,6 @@ class AsymmetricFocalTverskyLoss(_Loss):
2727
2828
It supports both binary and multi-class segmentation.
2929
30-
The logic assumes channel 0 is Background, and channels 1..N are Foreground.
31-
3230
Reimplementation of the Asymmetric Focal Tversky Loss described in:
3331
3432
- "Unified Focal Loss: Generalising Dice and Cross Entropy-based Losses to Handle Class Imbalanced Medical Image Segmentation",
@@ -80,12 +78,12 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
8078
# clip the prediction to avoid NaN
8179
y_pred = torch.clamp(y_pred, self.epsilon, 1.0 - self.epsilon)
8280

83-
spatial_dims = list(range(2, len(y_pred.shape)))
81+
axis = list(range(2, len(y_pred.shape)))
8482

8583
# Calculate true positives (tp), false negatives (fn) and false positives (fp)
86-
tp = torch.sum(y_true * y_pred, dim=spatial_dims)
87-
fn = torch.sum(y_true * (1 - y_pred), dim=spatial_dims)
88-
fp = torch.sum((1 - y_true) * y_pred, dim=spatial_dims)
84+
tp = torch.sum(y_true * y_pred, dim=axis)
85+
fn = torch.sum(y_true * (1 - y_pred), dim=axis)
86+
fp = torch.sum((1 - y_true) * y_pred, dim=axis)
8987
dice_class = (tp + self.epsilon) / (tp + self.delta * fn + (1 - self.delta) * fp + self.epsilon)
9088

9189
# Calculate losses separately for each class, enhancing both classes
@@ -200,32 +198,31 @@ def __init__(
200198
"""
201199
Args:
202200
to_onehot_y : whether to convert `y` into the one-hot format. Defaults to False.
203-
delta : weight of the background. Defaults to 0.7.
204-
gamma : value of the exponent gamma in the definition of the Focal loss. Defaults to 0.75.
205-
epsilon : it defines a very small number each time. simmily smooth value. Defaults to 1e-7.
206201
weight : weight for each loss function, if it's none it's 0.5. Defaults to None.
202+
gamma : value of the exponent gamma in the definition of the Focal loss. Defaults to 0.75.
203+
delta : weight of the background. Defaults to 0.7.
207204
use_softmax: whether to use softmax to transform the original logits into probabilities.
208205
If True, softmax is used. If False, sigmoid is used. Defaults to False.
209206
210207
Example:
211208
>>> import torch
212209
>>> from monai.losses import AsymmetricUnifiedFocalLoss
213-
>>> pred = torch.ones((1,1,32,32), dtype=torch.float32)
214-
>>> grnd = torch.ones((1,1,32,32), dtype=torch.int64)
215-
>>> fl = AsymmetricUnifiedFocalLoss(to_onehot_y=True)
210+
>>> pred = torch.randn((1, 3, 32, 32))
211+
>>> grnd = torch.randint(0, 3, (1, 1, 32, 32))
212+
>>> fl = AsymmetricUnifiedFocalLoss(use_softmax=True, to_onehot_y=True)
216213
>>> fl(pred, grnd)
217214
"""
218215
super().__init__(reduction=LossReduction(reduction).value)
219216
self.to_onehot_y = to_onehot_y
217+
self.weight: float = weight
220218
self.gamma = gamma
221219
self.delta = delta
222-
self.weight: float = weight
223220
self.use_softmax = use_softmax
224221
self.asy_focal_loss = AsymmetricFocalLoss(
225-
gamma=self.gamma, delta=self.delta, use_softmax=use_softmax, to_onehot_y=to_onehot_y
222+
to_onehot_y=to_onehot_y, gamma=self.gamma, delta=self.delta, use_softmax=use_softmax
226223
)
227224
self.asy_focal_tversky_loss = AsymmetricFocalTverskyLoss(
228-
gamma=self.gamma, delta=self.delta, use_softmax=use_softmax, to_onehot_y=to_onehot_y
225+
to_onehot_y=to_onehot_y, gamma=self.gamma, delta=self.delta, use_softmax=use_softmax
229226
)
230227

231228
def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:

0 commit comments

Comments
 (0)