Skip to content

Commit c27945a

Browse files
committed
Add sigmoid/softmax interface for AsymmetricUnifiedFocalLoss
Signed-off-by: ytl0623 <[email protected]>
1 parent 15fd428 commit c27945a

File tree

1 file changed

+94
-140
lines changed

1 file changed

+94
-140
lines changed

monai/losses/unified_focal_loss.py

Lines changed: 94 additions & 140 deletions
Original file line numberDiff line numberDiff line change
@@ -12,19 +12,20 @@
1212
from __future__ import annotations
1313

1414
import warnings
15+
from collections.abc import Sequence
1516

1617
import torch
18+
import torch.nn.functional as F
1719
from torch.nn.modules.loss import _Loss
1820

21+
from monai.losses import FocalLoss
1922
from monai.networks import one_hot
2023
from monai.utils import LossReduction
2124

2225

2326
class AsymmetricFocalTverskyLoss(_Loss):
2427
"""
25-
AsymmetricFocalTverskyLoss is a variant of FocalTverskyLoss, which attentions to the foreground class.
26-
27-
Actually, it's only supported for binary image segmentation now.
28+
AsymmetricFocalTverskyLoss is a variant of FocalTverskyLoss, which focuses on the foreground class.
2829
2930
Reimplementation of the Asymmetric Focal Tversky Loss described in:
3031
@@ -34,6 +35,7 @@ class AsymmetricFocalTverskyLoss(_Loss):
3435

3536
def __init__(
3637
self,
38+
include_background: bool = True,
3739
to_onehot_y: bool = False,
3840
delta: float = 0.7,
3941
gamma: float = 0.75,
@@ -42,18 +44,27 @@ def __init__(
4244
) -> None:
4345
"""
4446
Args:
47+
include_background: if False, channel index 0 (background category) is excluded from the loss calculation. Defaults to True.
4548
to_onehot_y: whether to convert `y` into the one-hot format. Defaults to False.
46-
delta : weight of the background. Defaults to 0.7.
47-
gamma : value of the exponent gamma in the definition of the Focal loss . Defaults to 0.75.
48-
epsilon : it defines a very small number each time. simmily smooth value. Defaults to 1e-7.
49+
delta: weight applied to false negatives in the Tversky index; false positives are weighted by ``1 - delta``. Defaults to 0.7.
50+
gamma: value of the exponent gamma in the definition of the Focal loss. Defaults to 0.75.
51+
epsilon: a small number to avoid division by zero. Defaults to 1e-7.
52+
reduction: {``"none"``, ``"mean"``, ``"sum"``}
53+
Specifies the reduction to apply to the output. Defaults to ``"mean"``.
4954
"""
5055
super().__init__(reduction=LossReduction(reduction).value)
56+
self.include_background = include_background
5157
self.to_onehot_y = to_onehot_y
5258
self.delta = delta
5359
self.gamma = gamma
5460
self.epsilon = epsilon
5561

5662
def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
63+
"""
64+
Args:
65+
y_pred: the shape should be BNH[WD], where N is the number of classes.
66+
y_true: the shape should be BNH[WD], where N is the number of classes.
67+
"""
5768
n_pred_ch = y_pred.shape[1]
5869

5970
if self.to_onehot_y:
@@ -62,179 +73,122 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
6273
else:
6374
y_true = one_hot(y_true, num_classes=n_pred_ch)
6475

76+
if not self.include_background:
77+
if n_pred_ch == 1:
78+
warnings.warn("single channel prediction, `include_background=False` ignored.")
79+
else:
80+
# if skipping background, removing first channel
81+
y_true = y_true[:, 1:]
82+
y_pred = y_pred[:, 1:]
83+
6584
if y_true.shape != y_pred.shape:
6685
raise ValueError(f"ground truth has different shape ({y_true.shape}) from input ({y_pred.shape})")
6786

6887
# clip the prediction to avoid NaN
6988
y_pred = torch.clamp(y_pred, self.epsilon, 1.0 - self.epsilon)
70-
axis = list(range(2, len(y_pred.shape)))
7189

7290
# Calculate true positives (tp), false negatives (fn) and false positives (fp)
91+
# Sum over the spatial dimensions only (dims 2 and up): (B, C, H[, W, D]) -> (B, C)
92+
axis = list(range(2, len(y_pred.shape)))
7393
tp = torch.sum(y_true * y_pred, dim=axis)
7494
fn = torch.sum(y_true * (1 - y_pred), dim=axis)
7595
fp = torch.sum((1 - y_true) * y_pred, dim=axis)
76-
dice_class = (tp + self.epsilon) / (tp + self.delta * fn + (1 - self.delta) * fp + self.epsilon)
7796

78-
# Calculate losses separately for each class, enhancing both classes
79-
back_dice = 1 - dice_class[:, 0]
80-
fore_dice = (1 - dice_class[:, 1]) * torch.pow(1 - dice_class[:, 1], -self.gamma)
81-
82-
# Average class scores
83-
loss = torch.mean(torch.stack([back_dice, fore_dice], dim=-1))
84-
return loss
85-
86-
87-
class AsymmetricFocalLoss(_Loss):
88-
"""
89-
AsymmetricFocalLoss is a variant of FocalTverskyLoss, which attentions to the foreground class.
90-
91-
Actually, it's only supported for binary image segmentation now.
92-
93-
Reimplementation of the Asymmetric Focal Loss described in:
94-
95-
- "Unified Focal Loss: Generalising Dice and Cross Entropy-based Losses to Handle Class Imbalanced Medical Image Segmentation",
96-
Michael Yeung, Computerized Medical Imaging and Graphics
97-
"""
97+
dice_class = (tp + self.epsilon) / (tp + self.delta * fn + (1 - self.delta) * fp + self.epsilon)
9898

99-
def __init__(
100-
self,
101-
to_onehot_y: bool = False,
102-
delta: float = 0.7,
103-
gamma: float = 2,
104-
epsilon: float = 1e-7,
105-
reduction: LossReduction | str = LossReduction.MEAN,
106-
):
107-
"""
108-
Args:
109-
to_onehot_y : whether to convert `y` into the one-hot format. Defaults to False.
110-
delta : weight of the background. Defaults to 0.7.
111-
gamma : value of the exponent gamma in the definition of the Focal loss . Defaults to 0.75.
112-
epsilon : it defines a very small number each time. simmily smooth value. Defaults to 1e-7.
113-
"""
114-
super().__init__(reduction=LossReduction(reduction).value)
115-
self.to_onehot_y = to_onehot_y
116-
self.delta = delta
117-
self.gamma = gamma
118-
self.epsilon = epsilon
99+
# Calculate losses separately for each class
100+
# Background (index 0) treated normally: 1 - dice
101+
# Foreground (index > 0) treated with focal modulation: (1 - dice)^(1-gamma)
119102

120-
def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
121-
n_pred_ch = y_pred.shape[1]
103+
# Note: channel 0 is treated as background (plain 1 - dice) only when
104+
# include_background is True. When include_background is False and the input is
105+
# multi-channel, the background channel has already been sliced off above, so
106+
# index 0 is the first foreground class and every channel gets focal modulation.
122107

123-
if self.to_onehot_y:
124-
if n_pred_ch == 1:
125-
warnings.warn("single channel prediction, `to_onehot_y=True` ignored.")
108+
loss_list = []
109+
for i in range(y_pred.shape[1]):
110+
# If this is the background channel (index 0 and included), use standard Dice loss
111+
if i == 0 and self.include_background:
112+
loss_list.append(1 - dice_class[:, i])
126113
else:
127-
y_true = one_hot(y_true, num_classes=n_pred_ch)
128-
129-
if y_true.shape != y_pred.shape:
130-
raise ValueError(f"ground truth has different shape ({y_true.shape}) from input ({y_pred.shape})")
114+
# Foreground classes: apply focal modulation
115+
# Original logic: (1 - dice) * (1 - dice)^(-gamma) -> (1-dice)^(1-gamma)
116+
loss_list.append((1 - dice_class[:, i]) * torch.pow(1 - dice_class[:, i], -self.gamma))
131117

132-
y_pred = torch.clamp(y_pred, self.epsilon, 1.0 - self.epsilon)
133-
cross_entropy = -y_true * torch.log(y_pred)
134-
135-
back_ce = torch.pow(1 - y_pred[:, 0], self.gamma) * cross_entropy[:, 0]
136-
back_ce = (1 - self.delta) * back_ce
118+
loss = torch.stack(loss_list, dim=-1)
137119

138-
fore_ce = cross_entropy[:, 1]
139-
fore_ce = self.delta * fore_ce
120+
if self.reduction == LossReduction.SUM.value:
121+
return loss.sum()
122+
if self.reduction == LossReduction.NONE.value:
123+
return loss
124+
if self.reduction == LossReduction.MEAN.value:
125+
return loss.mean()
140126

141-
loss = torch.mean(torch.sum(torch.stack([back_ce, fore_ce], dim=1), dim=1))
142-
return loss
127+
raise ValueError(f'Unsupported reduction: {self.reduction}, available options are ["mean", "sum", "none"].')
143128

144129

145130
class AsymmetricUnifiedFocalLoss(_Loss):
146131
"""
147-
AsymmetricUnifiedFocalLoss is a variant of Focal Loss.
148-
149-
Actually, it's only supported for binary image segmentation now
150-
151-
Reimplementation of the Asymmetric Unified Focal Tversky Loss described in:
152-
153-
- "Unified Focal Loss: Generalising Dice and Cross Entropy-based Losses to Handle Class Imbalanced Medical Image Segmentation",
154-
Michael Yeung, Computerized Medical Imaging and Graphics
132+
AsymmetricUnifiedFocalLoss is a variant of Focal Loss that combines Focal Loss and
133+
Asymmetric Focal Tversky Loss to handle class imbalance.
155134
"""
156135

157136
def __init__(
158137
self,
159-
to_onehot_y: bool = False,
160-
num_classes: int = 2,
161138
weight: float = 0.5,
139+
delta: float = 0.6,
162140
gamma: float = 0.5,
163-
delta: float = 0.7,
141+
include_background: bool = True,
142+
to_onehot_y: bool = False,
143+
use_softmax: bool = False,
164144
reduction: LossReduction | str = LossReduction.MEAN,
165-
):
145+
) -> None:
166146
"""
167147
Args:
168-
to_onehot_y : whether to convert `y` into the one-hot format. Defaults to False.
169-
num_classes : number of classes, it only supports 2 now. Defaults to 2.
170-
delta : weight of the background. Defaults to 0.7.
171-
gamma : value of the exponent gamma in the definition of the Focal loss. Defaults to 0.75.
172-
epsilon : it defines a very small number each time. simmily smooth value. Defaults to 1e-7.
173-
weight : weight for each loss function, if it's none it's 0.5. Defaults to None.
174-
175-
Example:
176-
>>> import torch
177-
>>> from monai.losses import AsymmetricUnifiedFocalLoss
178-
>>> pred = torch.ones((1,1,32,32), dtype=torch.float32)
179-
>>> grnd = torch.ones((1,1,32,32), dtype=torch.int64)
180-
>>> fl = AsymmetricUnifiedFocalLoss(to_onehot_y=True)
181-
>>> fl(pred, grnd)
148+
weight: the weighting factor 'lambda' between the two terms; the result is ``weight * focal_loss + (1 - weight) * asymmetric_focal_tversky_loss``. Defaults to 0.5.
149+
delta: weight applied to false negatives in the Tversky term; false positives are weighted by ``1 - delta``. Defaults to 0.6.
150+
gamma: value of the exponent gamma in the definition of the Focal loss. Defaults to 0.5.
151+
include_background: if False, channel index 0 (background category) is excluded. Defaults to True.
152+
to_onehot_y: whether to convert `y` into the one-hot format. Defaults to False.
153+
use_softmax: whether to use softmax to transform the original logits into probabilities.
154+
If True, softmax is used. If False, sigmoid is used. Defaults to False.
155+
reduction: Specifies the reduction to apply to the output. Defaults to ``"mean"``.
182156
"""
183157
super().__init__(reduction=LossReduction(reduction).value)
184-
self.to_onehot_y = to_onehot_y
185-
self.num_classes = num_classes
186-
self.gamma = gamma
187-
self.delta = delta
188-
self.weight: float = weight
189-
self.asy_focal_loss = AsymmetricFocalLoss(gamma=self.gamma, delta=self.delta)
190-
self.asy_focal_tversky_loss = AsymmetricFocalTverskyLoss(gamma=self.gamma, delta=self.delta)
158+
self.weight = weight
159+
self.use_softmax = use_softmax # store the flag; forward() reads it to choose softmax or sigmoid
160+
161+
self.focal_loss = FocalLoss(
162+
include_background=include_background,
163+
to_onehot_y=to_onehot_y,
164+
gamma=gamma,
165+
reduction=reduction,
166+
use_softmax=use_softmax,
167+
)
168+
169+
self.asy_focal_tversky_loss = AsymmetricFocalTverskyLoss(
170+
include_background=include_background,
171+
to_onehot_y=to_onehot_y,
172+
delta=delta,
173+
gamma=gamma,
174+
reduction=reduction,
175+
)
191176

192-
# TODO: Implement this function to support multiple classes segmentation
193177
def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
194178
"""
195179
Args:
196-
y_pred : the shape should be BNH[WD], where N is the number of classes.
197-
It only supports binary segmentation.
198-
The input should be the original logits since it will be transformed by
199-
a sigmoid in the forward function.
200-
y_true : the shape should be BNH[WD], where N is the number of classes.
201-
It only supports binary segmentation.
202-
203-
Raises:
204-
ValueError: When input and target are different shape
205-
ValueError: When len(y_pred.shape) != 4 and len(y_pred.shape) != 5
206-
ValueError: When num_classes
207-
ValueError: When the number of classes entered does not match the expected number
180+
y_pred: (BNH[WD]) Logits (raw scores).
181+
y_true: (BNH[WD]) Ground truth labels.
208182
"""
209-
if y_pred.shape != y_true.shape:
210-
raise ValueError(f"ground truth has different shape ({y_true.shape}) from input ({y_pred.shape})")
211-
212-
if len(y_pred.shape) != 4 and len(y_pred.shape) != 5:
213-
raise ValueError(f"input shape must be 4 or 5, but got {y_pred.shape}")
214-
215-
if y_pred.shape[1] == 1:
216-
y_pred = one_hot(y_pred, num_classes=self.num_classes)
217-
y_true = one_hot(y_true, num_classes=self.num_classes)
183+
focal_loss = self.focal_loss(y_pred, y_true)
218184

219-
if torch.max(y_true) != self.num_classes - 1:
220-
raise ValueError(f"Please make sure the number of classes is {self.num_classes-1}")
221-
222-
n_pred_ch = y_pred.shape[1]
223-
if self.to_onehot_y:
224-
if n_pred_ch == 1:
225-
warnings.warn("single channel prediction, `to_onehot_y=True` ignored.")
226-
else:
227-
y_true = one_hot(y_true, num_classes=n_pred_ch)
185+
if self.use_softmax:
186+
y_pred_prob = torch.softmax(y_pred, dim=1)
187+
else:
188+
y_pred_prob = torch.sigmoid(y_pred)
228189

229-
asy_focal_loss = self.asy_focal_loss(y_pred, y_true)
230-
asy_focal_tversky_loss = self.asy_focal_tversky_loss(y_pred, y_true)
190+
tversky_loss = self.asy_focal_tversky_loss(y_pred_prob, y_true)
231191

232-
loss: torch.Tensor = self.weight * asy_focal_loss + (1 - self.weight) * asy_focal_tversky_loss
192+
loss = self.weight * focal_loss + (1 - self.weight) * tversky_loss
233193

234-
if self.reduction == LossReduction.SUM.value:
235-
return torch.sum(loss) # sum over the batch and channel dims
236-
if self.reduction == LossReduction.NONE.value:
237-
return loss # returns [N, num_classes] losses
238-
if self.reduction == LossReduction.MEAN.value:
239-
return torch.mean(loss)
240-
raise ValueError(f'Unsupported reduction: {self.reduction}, available options are ["mean", "sum", "none"].')
194+
return loss

0 commit comments

Comments
 (0)