 from __future__ import annotations
 
 import warnings
+from collections.abc import Sequence
 
 import torch
+import torch.nn.functional as F
 from torch.nn.modules.loss import _Loss
 
+from monai.losses import FocalLoss
 from monai.networks import one_hot
 from monai.utils import LossReduction
 
 
 class AsymmetricFocalTverskyLoss(_Loss):
     """
-    AsymmetricFocalTverskyLoss is a variant of FocalTverskyLoss, which attentions to the foreground class.
-
-    Actually, it's only supported for binary image segmentation now.
+    AsymmetricFocalTverskyLoss is a variant of FocalTverskyLoss, which focuses on the foreground class.
 
     Reimplementation of the Asymmetric Focal Tversky Loss described in:
 
@@ -34,6 +35,7 @@ class AsymmetricFocalTverskyLoss(_Loss):
 
     def __init__(
         self,
+        include_background: bool = True,
         to_onehot_y: bool = False,
         delta: float = 0.7,
         gamma: float = 0.75,
@@ -42,18 +44,27 @@ def __init__(
     ) -> None:
         """
         Args:
+            include_background: if False, channel index 0 (background category) is excluded from the loss calculation.
             to_onehot_y: whether to convert `y` into the one-hot format. Defaults to False.
-            delta : weight of the background. Defaults to 0.7.
-            gamma : value of the exponent gamma in the definition of the Focal loss . Defaults to 0.75.
-            epsilon : it defines a very small number each time. simmily smooth value. Defaults to 1e-7.
+            delta: weight of the background. Defaults to 0.7.
+            gamma: value of the exponent gamma in the definition of the Focal loss. Defaults to 0.75.
+            epsilon: a small number to avoid division by zero. Defaults to 1e-7.
+            reduction: {``"none"``, ``"mean"``, ``"sum"``}
+                Specifies the reduction to apply to the output. Defaults to ``"mean"``.
         """
         super().__init__(reduction=LossReduction(reduction).value)
+        self.include_background = include_background
         self.to_onehot_y = to_onehot_y
         self.delta = delta
         self.gamma = gamma
         self.epsilon = epsilon
 
     def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
+        """
+        Args:
+            y_pred: the shape should be BNH[WD], where N is the number of classes.
+            y_true: the shape should be BNH[WD], where N is the number of classes.
+        """
         n_pred_ch = y_pred.shape[1]
 
         if self.to_onehot_y:
@@ -62,179 +73,122 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
             else:
                 y_true = one_hot(y_true, num_classes=n_pred_ch)
 
+        if not self.include_background:
+            if n_pred_ch == 1:
+                warnings.warn("single channel prediction, `include_background=False` ignored.")
+            else:
+                # if skipping background, remove the first channel
+                y_true = y_true[:, 1:]
+                y_pred = y_pred[:, 1:]
+
         if y_true.shape != y_pred.shape:
             raise ValueError(f"ground truth has different shape ({y_true.shape}) from input ({y_pred.shape})")
 
         # clip the prediction to avoid NaN
         y_pred = torch.clamp(y_pred, self.epsilon, 1.0 - self.epsilon)
-        axis = list(range(2, len(y_pred.shape)))
 
         # Calculate true positives (tp), false negatives (fn) and false positives (fp)
+        # Sum over spatial dimensions (B, C, H, W, D) -> (B, C)
+        axis = list(range(2, len(y_pred.shape)))
         tp = torch.sum(y_true * y_pred, dim=axis)
         fn = torch.sum(y_true * (1 - y_pred), dim=axis)
         fp = torch.sum((1 - y_true) * y_pred, dim=axis)
-        dice_class = (tp + self.epsilon) / (tp + self.delta * fn + (1 - self.delta) * fp + self.epsilon)
 
-        # Calculate losses separately for each class, enhancing both classes
-        back_dice = 1 - dice_class[:, 0]
-        fore_dice = (1 - dice_class[:, 1]) * torch.pow(1 - dice_class[:, 1], -self.gamma)
-
-        # Average class scores
-        loss = torch.mean(torch.stack([back_dice, fore_dice], dim=-1))
-        return loss
-
-
-class AsymmetricFocalLoss(_Loss):
-    """
-    AsymmetricFocalLoss is a variant of FocalTverskyLoss, which attentions to the foreground class.
-
-    Actually, it's only supported for binary image segmentation now.
-
-    Reimplementation of the Asymmetric Focal Loss described in:
-
-    - "Unified Focal Loss: Generalising Dice and Cross Entropy-based Losses to Handle Class Imbalanced Medical Image Segmentation",
-    Michael Yeung, Computerized Medical Imaging and Graphics
-    """
+        dice_class = (tp + self.epsilon) / (tp + self.delta * fn + (1 - self.delta) * fp + self.epsilon)
 
-    def __init__(
-        self,
-        to_onehot_y: bool = False,
-        delta: float = 0.7,
-        gamma: float = 2,
-        epsilon: float = 1e-7,
-        reduction: LossReduction | str = LossReduction.MEAN,
-    ):
-        """
-        Args:
-            to_onehot_y : whether to convert `y` into the one-hot format. Defaults to False.
-            delta : weight of the background. Defaults to 0.7.
-            gamma : value of the exponent gamma in the definition of the Focal loss . Defaults to 0.75.
-            epsilon : it defines a very small number each time. simmily smooth value. Defaults to 1e-7.
-        """
-        super().__init__(reduction=LossReduction(reduction).value)
-        self.to_onehot_y = to_onehot_y
-        self.delta = delta
-        self.gamma = gamma
-        self.epsilon = epsilon
+        # Calculate losses separately for each class
+        # Background (index 0) treated normally: 1 - dice
+        # Foreground (index > 0) treated with focal modulation: (1 - dice)^(1-gamma)
 
-    def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
-        n_pred_ch = y_pred.shape[1]
+        # Note: If include_background is False, index 0 is actually the first foreground class.
+        # We generally apply the asymmetry between the FIRST channel and the REST.
+        # However, for rigorous multi-class 'Asymmetric' implementation, we assume
+        # class 0 is background (if included) and others are foreground.
 
-        if self.to_onehot_y:
-            if n_pred_ch == 1:
-                warnings.warn("single channel prediction, `to_onehot_y=True` ignored.")
+        loss_list = []
+        for i in range(y_pred.shape[1]):
+            # If this is the background channel (index 0 and included), use standard Dice loss
+            if i == 0 and self.include_background:
+                loss_list.append(1 - dice_class[:, i])
             else:
-                y_true = one_hot(y_true, num_classes=n_pred_ch)
-
-        if y_true.shape != y_pred.shape:
-            raise ValueError(f"ground truth has different shape ({y_true.shape}) from input ({y_pred.shape})")
+                # Foreground classes: apply focal modulation
+                # Original logic: (1 - dice) * (1 - dice)^(-gamma) -> (1-dice)^(1-gamma)
+                loss_list.append((1 - dice_class[:, i]) * torch.pow(1 - dice_class[:, i], -self.gamma))
 
-        y_pred = torch.clamp(y_pred, self.epsilon, 1.0 - self.epsilon)
-        cross_entropy = -y_true * torch.log(y_pred)
-
-        back_ce = torch.pow(1 - y_pred[:, 0], self.gamma) * cross_entropy[:, 0]
-        back_ce = (1 - self.delta) * back_ce
+        loss = torch.stack(loss_list, dim=-1)
 
-        fore_ce = cross_entropy[:, 1]
-        fore_ce = self.delta * fore_ce
+        if self.reduction == LossReduction.SUM.value:
+            return loss.sum()
+        if self.reduction == LossReduction.NONE.value:
+            return loss
+        if self.reduction == LossReduction.MEAN.value:
+            return loss.mean()
 
-        loss = torch.mean(torch.sum(torch.stack([back_ce, fore_ce], dim=1), dim=1))
-        return loss
+        raise ValueError(f'Unsupported reduction: {self.reduction}, available options are ["mean", "sum", "none"].')
 
 
 class AsymmetricUnifiedFocalLoss(_Loss):
     """
-    AsymmetricUnifiedFocalLoss is a variant of Focal Loss.
-
-    Actually, it's only supported for binary image segmentation now
-
-    Reimplementation of the Asymmetric Unified Focal Tversky Loss described in:
-
-    - "Unified Focal Loss: Generalising Dice and Cross Entropy-based Losses to Handle Class Imbalanced Medical Image Segmentation",
-    Michael Yeung, Computerized Medical Imaging and Graphics
+    AsymmetricUnifiedFocalLoss is a variant of Focal Loss that combines Focal Loss and
+    Asymmetric Focal Tversky Loss to handle class imbalance.
     """
 
     def __init__(
         self,
-        to_onehot_y: bool = False,
-        num_classes: int = 2,
         weight: float = 0.5,
+        delta: float = 0.6,
         gamma: float = 0.5,
-        delta: float = 0.7,
+        include_background: bool = True,
+        to_onehot_y: bool = False,
+        use_softmax: bool = False,
         reduction: LossReduction | str = LossReduction.MEAN,
-    ):
+    ) -> None:
         """
         Args:
-            to_onehot_y : whether to convert `y` into the one-hot format. Defaults to False.
-            num_classes : number of classes, it only supports 2 now. Defaults to 2.
-            delta : weight of the background. Defaults to 0.7.
-            gamma : value of the exponent gamma in the definition of the Focal loss. Defaults to 0.75.
-            epsilon : it defines a very small number each time. simmily smooth value. Defaults to 1e-7.
-            weight : weight for each loss function, if it's none it's 0.5. Defaults to None.
-
-        Example:
-            >>> import torch
-            >>> from monai.losses import AsymmetricUnifiedFocalLoss
-            >>> pred = torch.ones((1,1,32,32), dtype=torch.float32)
-            >>> grnd = torch.ones((1,1,32,32), dtype=torch.int64)
-            >>> fl = AsymmetricUnifiedFocalLoss(to_onehot_y=True)
-            >>> fl(pred, grnd)
+            weight: the weighting factor 'lambda' between Focal Loss and Asymmetric Focal Tversky Loss.
+            delta: weight of the background class (used in Tversky). Defaults to 0.6.
+            gamma: value of the exponent gamma in the definition of the Focal loss. Defaults to 0.5.
+            include_background: if False, channel index 0 (background category) is excluded.
+            to_onehot_y: whether to convert `y` into the one-hot format. Defaults to False.
+            use_softmax: whether to use softmax to transform the original logits into probabilities.
+                If True, softmax is used. If False, sigmoid is used. Defaults to False.
+            reduction: Specifies the reduction to apply to the output. Defaults to ``"mean"``.
         """
         super().__init__(reduction=LossReduction(reduction).value)
-        self.to_onehot_y = to_onehot_y
-        self.num_classes = num_classes
-        self.gamma = gamma
-        self.delta = delta
-        self.weight: float = weight
-        self.asy_focal_loss = AsymmetricFocalLoss(gamma=self.gamma, delta=self.delta)
-        self.asy_focal_tversky_loss = AsymmetricFocalTverskyLoss(gamma=self.gamma, delta=self.delta)
+        self.weight = weight
+        self.use_softmax = use_softmax  # store the parameter for use in forward()
+
+        self.focal_loss = FocalLoss(
+            include_background=include_background,
+            to_onehot_y=to_onehot_y,
+            gamma=gamma,
+            reduction=reduction,
+            use_softmax=use_softmax,
+        )
+
+        self.asy_focal_tversky_loss = AsymmetricFocalTverskyLoss(
+            include_background=include_background,
+            to_onehot_y=to_onehot_y,
+            delta=delta,
+            gamma=gamma,
+            reduction=reduction,
+        )
 
-    # TODO: Implement this function to support multiple classes segmentation
     def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
         """
         Args:
-            y_pred : the shape should be BNH[WD], where N is the number of classes.
-                It only supports binary segmentation.
-                The input should be the original logits since it will be transformed by
-                a sigmoid in the forward function.
-            y_true : the shape should be BNH[WD], where N is the number of classes.
-                It only supports binary segmentation.
-
-        Raises:
-            ValueError: When input and target are different shape
-            ValueError: When len(y_pred.shape) != 4 and len(y_pred.shape) != 5
-            ValueError: When num_classes
-            ValueError: When the number of classes entered does not match the expected number
+            y_pred: (BNH[WD]) Logits (raw scores).
+            y_true: (BNH[WD]) Ground truth labels.
         """
-        if y_pred.shape != y_true.shape:
-            raise ValueError(f"ground truth has different shape ({y_true.shape}) from input ({y_pred.shape})")
-
-        if len(y_pred.shape) != 4 and len(y_pred.shape) != 5:
-            raise ValueError(f"input shape must be 4 or 5, but got {y_pred.shape}")
-
-        if y_pred.shape[1] == 1:
-            y_pred = one_hot(y_pred, num_classes=self.num_classes)
-            y_true = one_hot(y_true, num_classes=self.num_classes)
+        focal_loss = self.focal_loss(y_pred, y_true)
 
-        if torch.max(y_true) != self.num_classes - 1:
-            raise ValueError(f"Please make sure the number of classes is {self.num_classes - 1}")
-
-        n_pred_ch = y_pred.shape[1]
-        if self.to_onehot_y:
-            if n_pred_ch == 1:
-                warnings.warn("single channel prediction, `to_onehot_y=True` ignored.")
-            else:
-                y_true = one_hot(y_true, num_classes=n_pred_ch)
+        if self.use_softmax:
+            y_pred_prob = torch.softmax(y_pred, dim=1)
+        else:
+            y_pred_prob = torch.sigmoid(y_pred)
 
-        asy_focal_loss = self.asy_focal_loss(y_pred, y_true)
-        asy_focal_tversky_loss = self.asy_focal_tversky_loss(y_pred, y_true)
+        tversky_loss = self.asy_focal_tversky_loss(y_pred_prob, y_true)
 
-        loss: torch.Tensor = self.weight * asy_focal_loss + (1 - self.weight) * asy_focal_tversky_loss
+        loss = self.weight * focal_loss + (1 - self.weight) * tversky_loss
 
-        if self.reduction == LossReduction.SUM.value:
-            return torch.sum(loss)  # sum over the batch and channel dims
-        if self.reduction == LossReduction.NONE.value:
-            return loss  # returns [N, num_classes] losses
-        if self.reduction == LossReduction.MEAN.value:
-            return torch.mean(loss)
-        raise ValueError(f'Unsupported reduction: {self.reduction}, available options are ["mean", "sum", "none"].')
+        return loss
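
For reference, the math implemented by the patched classes (a restatement of the formulas already present in the diff, with $\delta$ = `delta`, $\gamma$ = `gamma`, $\lambda$ = `weight`, $\epsilon$ = `epsilon`): for each class $c$, with true positives, false negatives and false positives summed over the spatial dimensions,

$$
\mathrm{TI}_c = \frac{TP_c + \epsilon}{TP_c + \delta\, FN_c + (1-\delta)\, FP_c + \epsilon}, \qquad
L_0 = 1 - \mathrm{TI}_0, \qquad
L_{c>0} = \left(1 - \mathrm{TI}_c\right)^{1-\gamma},
$$

and `AsymmetricUnifiedFocalLoss` returns $\lambda \, L_{\text{Focal}} + (1-\lambda) \, L_{\text{AFT}}$.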
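
A minimal usage sketch of the patched `AsymmetricUnifiedFocalLoss` (assuming the diff above is applied and the class remains importable from `monai.losses`, as in the example removed from the old docstring; the shapes and hyperparameter values below are illustrative only):

```python
import torch

from monai.losses import AsymmetricUnifiedFocalLoss

# 3-class example: logits of shape (B, C, H, W), integer labels of shape (B, 1, H, W)
pred = torch.randn(2, 3, 32, 32)            # raw scores; the loss applies softmax/sigmoid itself
grnd = torch.randint(0, 3, (2, 1, 32, 32))  # class indices in [0, 2]

loss_fn = AsymmetricUnifiedFocalLoss(
    weight=0.5,        # lambda: balance between FocalLoss and AsymmetricFocalTverskyLoss
    delta=0.6,         # background weight inside the Tversky term
    gamma=0.5,         # focal exponent shared by both terms
    to_onehot_y=True,  # convert grnd to one-hot inside each sub-loss
    use_softmax=True,  # softmax over channels for the multi-class case
)

loss = loss_fn(pred, grnd)  # scalar with the default "mean" reduction
print(loss)
```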