@@ -49,7 +49,7 @@ def __init__(
4949 If True, softmax is used. If False, sigmoid is used. Defaults to False.
5050 delta : weight of the background. Defaults to 0.7.
5151 gamma : value of the exponent gamma in the definition of the Focal loss . Defaults to 0.75.
52- epsilon : it defines a very small number each time. simmily smooth value. Defaults to 1e-7.
52+ epsilon : a very small constant added for numerical stability, similar to a smoothing value. Defaults to 1e-7.
5353 """
5454 super ().__init__ (reduction = LossReduction (reduction ).value )
5555 self .to_onehot_y = to_onehot_y
@@ -88,12 +88,16 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
8888
8989 # Calculate losses separately for each class, enhancing both classes
9090 back_dice = 1 - dice_class [:, 0 ]
91- fore_dice = (1 - dice_class [:, 1 :]) * torch .pow (1 - dice_class [:, 1 :], - self .gamma )
9291
93- if fore_dice .shape [1 ] > 1 :
94- fore_dice = torch .mean (fore_dice , dim = 1 )
92+ if n_pred_ch > 1 :
93+ fore_dice = torch .pow (1 - dice_class [:, 1 :], 1 - self .gamma )
94+
95+ if fore_dice .shape [1 ] > 1 :
96+ fore_dice = torch .mean (fore_dice , dim = 1 )
97+ else :
98+ fore_dice = fore_dice .squeeze (1 )
9599 else :
96- fore_dice = fore_dice . squeeze ( 1 )
100+ fore_dice = torch . zeros_like ( back_dice )
97101
98102 # Average class scores
99103 loss = torch .mean (torch .stack ([back_dice , fore_dice ], dim = - 1 ))
@@ -128,7 +132,7 @@ def __init__(
128132 If True, softmax is used. If False, sigmoid is used. Defaults to False.
129133 delta : weight of the background. Defaults to 0.7.
130134 gamma : value of the exponent gamma in the definition of the Focal loss . Defaults to 0.75.
131- epsilon : it defines a very small number each time. simmily smooth value. Defaults to 1e-7.
135+ epsilon : a very small constant added for numerical stability, similar to a smoothing value. Defaults to 1e-7.
132136 """
133137 super ().__init__ (reduction = LossReduction (reduction ).value )
134138 self .to_onehot_y = to_onehot_y
@@ -198,8 +202,8 @@ def __init__(
198202 """
199203 Args:
200204 to_onehot_y : whether to convert `y` into the one-hot format. Defaults to False.
201- weight : weight for each loss function, if it's none it's 0.5. Defaults to None .
202- gamma : value of the exponent gamma in the definition of the Focal loss. Defaults to 0.75 .
205+ weight : weight for each loss function. Defaults to 0.5 .
206+ gamma : value of the exponent gamma in the definition of the Focal loss. Defaults to 0.5 .
203207 delta : weight of the background. Defaults to 0.7.
204208 use_softmax: whether to use softmax to transform the original logits into probabilities.
205209 If True, softmax is used. If False, sigmoid is used. Defaults to False.
@@ -235,7 +239,6 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
235239
236240 Raises:
237241 ValueError: When input and target are different shape
238- ValueError: When the number of classes entered does not match the expected number
239242 """
240243 if y_pred .shape != y_true .shape :
241244 raise ValueError (f"ground truth has different shape ({ y_true .shape } ) from input ({ y_pred .shape } )" )
0 commit comments