@@ -131,7 +131,7 @@ def __init__(
131131 use_softmax: whether to use softmax to transform the original logits into probabilities.
132132 If True, softmax is used. If False, sigmoid is used. Defaults to False.
133133 delta : weight of the background. Defaults to 0.7.
134- gamma : value of the exponent gamma in the definition of the Focal loss . Defaults to 0.75 .
134+ gamma : value of the exponent gamma in the definition of the Focal loss . Defaults to 2 .
135135 epsilon : it defines a very small number each time. similarly smooth value. Defaults to 1e-7.
136136 """
137137 super ().__init__ (reduction = LossReduction (reduction ).value )
@@ -166,13 +166,16 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
166166 back_ce = torch .pow (1 - y_pred [:, 0 ], self .gamma ) * cross_entropy [:, 0 ]
167167 back_ce = (1 - self .delta ) * back_ce
168168
169- fore_ce = cross_entropy [:, 1 :]
170- fore_ce = self .delta * fore_ce
169+ if n_pred_ch > 1 :
170+ fore_ce = cross_entropy [:, 1 :]
171+ fore_ce = self .delta * fore_ce
171172
172- if fore_ce .shape [1 ] > 1 :
173- fore_ce = torch .sum (fore_ce , dim = 1 )
173+ if fore_ce .shape [1 ] > 1 :
174+ fore_ce = torch .sum (fore_ce , dim = 1 )
175+ else :
176+ fore_ce = fore_ce .squeeze (1 )
174177 else :
175- fore_ce = fore_ce . squeeze ( 1 )
178+ fore_ce = torch . zeros_like ( back_ce )
176179
177180 loss = torch .mean (torch .stack ([back_ce , fore_ce ], dim = - 1 ))
178181 return loss
0 commit comments