@@ -48,8 +48,8 @@ def __init__(
             use_softmax: whether to use softmax to transform the original logits into probabilities.
                 If True, softmax is used. If False, sigmoid is used. Defaults to False.
             delta : weight of the background. Defaults to 0.7.
-            gamma : value of the exponent gamma in the definition of the Focal loss . Defaults to 0.75.
-            epsilon : it defines a very small number each time. similarly smooth value . Defaults to 1e-7.
+            gamma : value of the exponent gamma in the definition of the Focal loss. Defaults to 0.75.
+            epsilon : stability factor used to avoid division by zero. Defaults to 1e-7.
         """
         super().__init__(reduction=LossReduction(reduction).value)
         self.to_onehot_y = to_onehot_y
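For context on the new `epsilon` description: in the Tversky index this class computes, `epsilon` pads the ratio so that empty classes do not divide by zero. A minimal sketch, assuming the per-class `tp`/`fn`/`fp` sums that `forward()` derives later (the exact placement of `epsilon` follows the implementation, not this diff):

    import torch

    # Stand-ins for the per-class sums computed in forward(); the first class
    # is empty everywhere, so its denominator would be zero without epsilon.
    tp = torch.tensor([0.0, 12.0])
    fn = torch.tensor([0.0, 3.0])
    fp = torch.tensor([0.0, 5.0])
    delta, epsilon = 0.7, 1e-7

    dice_class = (tp + epsilon) / (tp + delta * fn + (1 - delta) * fp + epsilon)
    print(dice_class)  # tensor([1.0000, 0.7692]), no NaN for the empty class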
@@ -59,6 +59,17 @@ def __init__(
         self.epsilon = epsilon

     def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
+        n_pred_ch = y_pred.shape[1]
+
+        if self.use_softmax and n_pred_ch == 1:
+            raise ValueError("single channel prediction with `use_softmax=True` is not allowed.")
+
+        if self.to_onehot_y:
+            if n_pred_ch == 1:
+                warnings.warn("single channel prediction, `to_onehot_y=True` ignored.")
+            else:
+                y_true = one_hot(y_true, num_classes=n_pred_ch)
+
         if self.use_softmax:
             y_pred = torch.softmax(y_pred, dim=1)
         else:
@@ -68,17 +79,10 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
             y_pred = torch.cat([1 - y_pred, y_pred], dim=1)
             y_true = torch.cat([1 - y_true, y_true], dim=1)

-        n_pred_ch = y_pred.shape[1]
-
-        if self.to_onehot_y:
-            if n_pred_ch == 1:
-                warnings.warn("single channel prediction, `to_onehot_y=True` ignored.")
-            else:
-                y_true = one_hot(y_true, num_classes=n_pred_ch)
-
         if y_true.shape != y_pred.shape:
             raise ValueError(f"ground truth has different shape ({y_true.shape}) from input ({y_pred.shape})")

+        # Calculate Loss
         axis = list(range(2, len(y_pred.shape)))

         # Calculate true positives (tp), false negatives (fn) and false positives (fp)
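Moving the channel checks ahead of the binary expansion (the hunks above) also means misuse fails fast, before any tensors are reshaped. A hedged usage sketch, assuming the class as exported by `monai.losses` and arbitrary shapes:

    import torch
    from monai.losses import AsymmetricFocalTverskyLoss

    # Softmax over a single channel always yields probability 1.0, so this
    # combination is now rejected up front instead of silently degenerating.
    loss_fn = AsymmetricFocalTverskyLoss(use_softmax=True)
    pred = torch.randn(2, 1, 32, 32)   # single-channel logits
    label = torch.randint(0, 2, (2, 1, 32, 32)).float()
    try:
        loss_fn(pred, label)
    except ValueError as err:
        print(err)  # single channel prediction with `use_softmax=True` is not allowed.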
@@ -130,8 +134,8 @@ def __init__(
             use_softmax: whether to use softmax to transform the original logits into probabilities.
                 If True, softmax is used. If False, sigmoid is used. Defaults to False.
             delta : weight of the background. Defaults to 0.7.
-            gamma : value of the exponent gamma in the definition of the Focal loss . Defaults to 2.
-            epsilon : it defines a very small number each time. similarly smooth value . Defaults to 1e-7.
+            gamma : value of the exponent gamma in the definition of the Focal loss. Defaults to 2.
+            epsilon : stability factor used to avoid division by zero. Defaults to 1e-7.
         """
         super().__init__(reduction=LossReduction(reduction).value)
         self.to_onehot_y = to_onehot_y
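As background for the `delta` and `gamma` docstring entries above: this loss applies the focal exponent asymmetrically, damping only the background channel's cross-entropy while the foreground keeps plain delta-weighted cross-entropy. A rough sketch of the core term, assuming the two-channel `[background, foreground]` layout used in `forward()`:

    import torch

    y_pred = torch.tensor([[0.8, 0.2]])  # [background, foreground] probabilities
    y_true = torch.tensor([[1.0, 0.0]])
    delta, gamma = 0.7, 2.0

    ce = -y_true * torch.log(y_pred)
    back_ce = (1 - delta) * (1 - y_pred[:, 0]) ** gamma * ce[:, 0]  # focal-damped background
    fore_ce = delta * ce[:, 1]                                      # plain weighted foreground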
@@ -141,26 +145,34 @@ def __init__(
         self.epsilon = epsilon

     def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
+        n_pred_ch = y_pred.shape[1]
+
+        if self.use_softmax and n_pred_ch == 1:
+            raise ValueError("single channel prediction with `use_softmax=True` is not allowed.")
+
+        if self.to_onehot_y:
+            if n_pred_ch == 1:
+                warnings.warn("single channel prediction, `to_onehot_y=True` ignored.")
+            else:
+                y_true = one_hot(y_true, num_classes=n_pred_ch)
+
+        # Save logits for numerical stability in single-channel expansion
+        y_logits = y_pred
+
         if self.use_softmax:
             y_log_pred = F.log_softmax(y_pred, dim=1)
             y_pred = torch.exp(y_log_pred)
         else:
             y_log_pred = F.logsigmoid(y_pred)
             y_pred = torch.sigmoid(y_pred)

-        if y_pred.shape[1] == 1:
+        # Handle single-channel (binary) expansion
+        if n_pred_ch == 1:
             y_pred = torch.cat([1 - y_pred, y_pred], dim=1)
-            y_log_pred = torch.log(torch.clamp(y_pred, 1e-7, 1.0))
+            bg_log_pred = F.logsigmoid(-y_logits)
+            y_log_pred = torch.cat([bg_log_pred, y_log_pred], dim=1)
             y_true = torch.cat([1 - y_true, y_true], dim=1)

-        n_pred_ch = y_pred.shape[1]
-
-        if self.to_onehot_y:
-            if n_pred_ch == 1:
-                warnings.warn("single channel prediction, `to_onehot_y=True` ignored.")
-            else:
-                y_true = one_hot(y_true, num_classes=n_pred_ch)
-
         if y_true.shape != y_pred.shape:
             raise ValueError(f"ground truth has different shape ({y_true.shape}) from input ({y_pred.shape})")

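The clamp-to-log replacement above relies on the identity 1 - sigmoid(x) = sigmoid(-x). A small demonstration of the difference (float32 values):

    import torch
    import torch.nn.functional as F

    x = torch.tensor([40.0])   # a confidently-foreground logit
    p = torch.sigmoid(x)       # rounds to exactly 1.0 in float32

    # Old path: 1 - p underflows to 0, the clamp kicks in, and the background
    # log-probability saturates at log(1e-7) no matter how large x gets.
    old = torch.log(torch.clamp(1 - p, 1e-7, 1.0))  # tensor([-16.1181])

    # New path: log(1 - sigmoid(x)) computed exactly from the logit.
    new = F.logsigmoid(-x)                          # tensor([-40.])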
@@ -199,20 +211,18 @@ class AsymmetricUnifiedFocalLoss(_Loss):
     def __init__(
         self,
         to_onehot_y: bool = False,
-        weight: float = 0.5,
-        gamma: float = 0.5,
-        delta: float = 0.7,
         use_softmax: bool = False,
+        delta: float = 0.7,
+        gamma: float = 2,
         reduction: LossReduction | str = LossReduction.MEAN,
     ):
         """
         Args:
-            to_onehot_y : whether to convert `y` into the one-hot format. Defaults to False.
-            weight : weight for each loss function. Defaults to 0.5.
-            gamma : value of the exponent gamma in the definition of the Focal loss. Defaults to 0.5.
-            delta : weight of the background. Defaults to 0.7.
+            to_onehot_y: whether to convert `y` into the one-hot format. Defaults to False.
             use_softmax: whether to use softmax to transform the original logits into probabilities.
                 If True, softmax is used. If False, sigmoid is used. Defaults to False.
+            delta : weight of the background. Defaults to 0.7.
+            gamma : value of the exponent gamma in the definition of the Focal loss. Defaults to 2.

         Example:
             >>> import torch
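A hedged usage sketch for the reordered constructor. Note that `weight` is no longer a constructor argument in this hunk, so callers that passed `weight`, `gamma`, `delta` positionally must switch to keywords; whether `weight` stays configurable some other way is not visible in this diff:

    import torch
    from monai.losses import AsymmetricUnifiedFocalLoss

    loss_fn = AsymmetricUnifiedFocalLoss(to_onehot_y=True, use_softmax=True, delta=0.7, gamma=2)
    pred = torch.randn(2, 3, 32, 32)             # multi-class logits
    label = torch.randint(0, 3, (2, 1, 32, 32))  # index-format labels
    loss = loss_fn(pred, label)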
@@ -250,9 +260,9 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
         loss: torch.Tensor = self.weight * asy_focal_loss + (1 - self.weight) * asy_focal_tversky_loss

         if self.reduction == LossReduction.SUM.value:
-            return torch.sum(loss)  # sum over the batch and channel dims
+            return torch.sum(loss)
         if self.reduction == LossReduction.NONE.value:
-            return loss  # returns [N, num_classes] losses
+            return loss
         if self.reduction == LossReduction.MEAN.value:
             return torch.mean(loss)
         raise ValueError(f'Unsupported reduction: {self.reduction}, available options are ["mean", "sum", "none"].')
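Finally, a short sketch of the reduction dispatch above: "mean" and "sum" collapse to scalars, while "none" returns the unreduced loss tensor (its exact shape is not asserted here, since the old shape comment was dropped in this hunk):

    import torch
    from monai.losses import AsymmetricUnifiedFocalLoss

    pred = torch.randn(2, 3, 16, 16)
    label = torch.randint(0, 3, (2, 1, 16, 16))
    for reduction in ("mean", "sum", "none"):
        loss_fn = AsymmetricUnifiedFocalLoss(to_onehot_y=True, use_softmax=True, reduction=reduction)
        print(reduction, loss_fn(pred, label).shape)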