@@ -169,7 +169,7 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
169169 back_ce = (1.0 - self .delta ) * torch .pow (1.0 - y_pred [:, 0 ], self .gamma ) * cross_entropy [:, 0 ]
170170 # (B, C-1, H, W)
171171 fore_ce = self .delta * cross_entropy [:, 1 :]
172-
172+
173173 loss = torch .cat ([back_ce .unsqueeze (1 ), fore_ce ], dim = 1 ) # (B, C, H, W)
174174
175175 # Apply reduction
@@ -276,15 +276,15 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
276276 if y_true .shape [1 ] != 1 :
277277 y_true = y_true .unsqueeze (1 )
278278 y_true = one_hot (y_true , num_classes = n_pred_ch )
279-
279+
280280 # Ensure y_true has the same shape as y_pred_act
281281 if y_true .shape != y_pred_act .shape :
282282 # This can happen if y_true is (B, H, W) and y_pred is (B, 1, H, W) after sigmoid
283283 if y_true .shape [1 ] != y_pred_act .shape [1 ] and y_true .ndim == y_pred_act .ndim - 1 :
284284 y_true = y_true .unsqueeze (1 ) # Add channel dim
285-
285+
286286 if y_true .shape != y_pred_act .shape :
287- raise ValueError (f"ground truth has different shape ({ y_true .shape } ) from input ({ y_pred_act .shape } )
287+ raise ValueError (f"ground truth has different shape ({ y_true .shape } ) from input ({ y_pred_act .shape } )
288288 after activations / one - hot ")
289289
290290
@@ -293,4 +293,4 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
293293
294294 loss : torch .Tensor = self .lambda_focal * f_loss + (1 - self .lambda_focal ) * t_loss
295295
296- return loss
296+ return loss
0 commit comments