@@ -70,19 +70,21 @@ def __init__(
         include_background: bool = True,
         to_onehot_y: bool = False,
         gamma: float = 2.0,
-        alpha: float | None = None,
+        alpha: float | Sequence[float] | None = None,
         weight: Sequence[float] | float | int | torch.Tensor | None = None,
         reduction: LossReduction | str = LossReduction.MEAN,
         use_softmax: bool = False,
     ) -> None:
         """
         Args:
             include_background: if False, channel index 0 (background category) is excluded from the loss calculation.
-                If False, `alpha` is invalid when using softmax.
+                If False, `alpha` is invalid when using softmax unless `alpha` is a sequence (explicit per-class weights).
             to_onehot_y: whether to convert the label `y` into the one-hot format. Defaults to False.
             gamma: value of the exponent gamma in the definition of the Focal loss. Defaults to 2.
             alpha: value of the alpha in the definition of the alpha-balanced Focal loss.
-                The value should be in [0, 1]. Defaults to None.
+                The value should be in [0, 1].
+                If a sequence is provided, its length must match the number of classes (after excluding background if set).
+                Defaults to None.
             weight: weights to apply to the voxels of each class. If None no weights are applied.
                 The input can be a single value (same weight for all classes), a sequence of values (the length
                 of the sequence should be the same as the number of classes. If not ``include_background``,
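A minimal usage sketch of the sequence form of `alpha` documented above; the tensor shapes, class count, and values are illustrative assumptions, not part of this change:

# Sketch only: 3-channel softmax output, background excluded, per-class alpha.
import torch
from monai.losses import FocalLoss

pred = torch.randn(2, 3, 32, 32)             # (B, C, H, W) logits
label = torch.randint(0, 3, (2, 1, 32, 32))  # integer class labels

loss_fn = FocalLoss(
    include_background=False,
    to_onehot_y=True,
    use_softmax=True,
    gamma=2.0,
    alpha=[0.25, 0.75],  # one value per remaining class (background dropped)
)
print(loss_fn(pred, label))

With `include_background=False` and softmax, the sequence form is kept, whereas a scalar `alpha` would be dropped with a warning per the `forward` change below.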
@@ -156,13 +158,16 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
         loss: Optional[torch.Tensor] = None
         input = input.float()
         target = target.float()
+
+        alpha_arg = self.alpha
         if self.use_softmax:
             if not self.include_background and self.alpha is not None:
-                self.alpha = None
-                warnings.warn("`include_background=False`, `alpha` ignored when using softmax.")
-            loss = softmax_focal_loss(input, target, self.gamma, self.alpha)
+                if isinstance(self.alpha, (float, int)):
+                    alpha_arg = None
+                    warnings.warn("`include_background=False`, scalar `alpha` ignored when using softmax.")
+            loss = softmax_focal_loss(input, target, self.gamma, alpha_arg)
         else:
-            loss = sigmoid_focal_loss(input, target, self.gamma, self.alpha)
+            loss = sigmoid_focal_loss(input, target, self.gamma, alpha_arg)
 
         num_of_classes = target.shape[1]
         if self.class_weight is not None and num_of_classes != 1:
@@ -203,7 +208,7 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
 
 
 def softmax_focal_loss(
-    input: torch.Tensor, target: torch.Tensor, gamma: float = 2.0, alpha: Optional[float] = None
+    input: torch.Tensor, target: torch.Tensor, gamma: float = 2.0, alpha: float | Sequence[float] | None = None
 ) -> torch.Tensor:
     """
     FL(pt) = -alpha * (1 - pt)**gamma * log(pt)
@@ -215,8 +220,18 @@ def softmax_focal_loss(
     loss: torch.Tensor = -(1 - input_ls.exp()).pow(gamma) * input_ls * target
 
     if alpha is not None:
-        # (1-alpha) for the background class and alpha for the other classes
-        alpha_fac = torch.tensor([1 - alpha] + [alpha] * (target.shape[1] - 1)).to(loss)
+        alpha_t = torch.as_tensor(alpha, device=input.device, dtype=input.dtype)
+
+        if alpha_t.ndim == 0:  # scalar
+            # (1-alpha) for the background class and alpha for the other classes
+            alpha_fac = torch.tensor([1 - alpha] + [alpha] * (target.shape[1] - 1)).to(loss)
+        else:  # sequence
+            if alpha_t.shape[0] != target.shape[1]:
+                raise ValueError(
+                    f"The length of alpha ({alpha_t.shape[0]}) must match the number of classes ({target.shape[1]})."
+                )
+            alpha_fac = alpha_t
+
         broadcast_dims = [-1] + [1] * len(target.shape[2:])
         alpha_fac = alpha_fac.view(broadcast_dims)
         loss = alpha_fac * loss
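A small standalone check of the broadcasting used in the sequence branch above; the tensor sizes and alpha values are assumptions chosen for illustration:

import torch

loss = torch.ones(2, 3, 4, 4)             # (B, C, H, W) per-voxel focal loss
alpha_t = torch.tensor([0.2, 0.3, 0.5])   # one weight per class/channel

broadcast_dims = [-1] + [1] * (loss.ndim - 2)  # -> [-1, 1, 1]
alpha_fac = alpha_t.view(broadcast_dims)       # shape (3, 1, 1)

weighted = alpha_fac * loss                    # channel c scaled by alpha_t[c]
print(weighted[:, 0].unique(), weighted[:, 2].unique())  # tensor([0.2000]) tensor([0.5000])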
@@ -225,7 +240,7 @@ def softmax_focal_loss(
 
 
 def sigmoid_focal_loss(
-    input: torch.Tensor, target: torch.Tensor, gamma: float = 2.0, alpha: Optional[float] = None
+    input: torch.Tensor, target: torch.Tensor, gamma: float = 2.0, alpha: float | Sequence[float] | None = None
 ) -> torch.Tensor:
     """
     FL(pt) = -alpha * (1 - pt)**gamma * log(pt)
@@ -248,8 +263,21 @@ def sigmoid_focal_loss(
     loss = (invprobs * gamma).exp() * loss
 
     if alpha is not None:
-        # alpha if t==1; (1-alpha) if t==0
-        alpha_factor = target * alpha + (1 - target) * (1 - alpha)
+        alpha_t = torch.as_tensor(alpha, device=input.device, dtype=input.dtype)
+        if alpha_t.ndim == 0:  # scalar
+            # alpha if t==1; (1-alpha) if t==0
+            alpha_factor = target * alpha_t + (1 - target) * (1 - alpha_t)
+        else:  # sequence / per-channel alpha
+            if alpha_t.shape[0] != target.shape[1]:
+                raise ValueError(
+                    f"The length of alpha ({alpha_t.shape[0]}) must match the number of classes ({target.shape[1]})."
+                )
+            # Reshape alpha for broadcasting: (C, 1, 1, ...)
+            broadcast_dims = [-1] + [1] * len(target.shape[2:])
+            alpha_t = alpha_t.view(broadcast_dims)
+            # Apply alpha_c if t==1, (1-alpha_c) if t==0 for channel c
+            alpha_factor = target * alpha_t + (1 - target) * (1 - alpha_t)
+
         loss = alpha_factor * loss
 
     return loss
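A corresponding sketch of the per-channel weighting in the sigmoid branch; the one-hot target and alpha values are assumptions chosen to make the result easy to read:

import torch

target = torch.zeros(1, 2, 2, 2)     # one-hot target: (B=1, C=2, H=2, W=2)
target[:, 0, 0, :] = 1.0             # first row of channel 0 is positive
alpha_t = torch.tensor([0.25, 0.75]).view(-1, 1, 1)  # per-channel alpha -> (2, 1, 1)

# alpha_c where t == 1, (1 - alpha_c) where t == 0, per channel c
alpha_factor = target * alpha_t + (1 - target) * (1 - alpha_t)
print(alpha_factor[0, 0])  # rows: [0.25, 0.25] then [0.75, 0.75]
print(alpha_factor[0, 1])  # all 0.25, since channel 1 has no positives here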