Skip to content

Commit 100e9f8

Browse files
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent dca758d commit 100e9f8

File tree

1 file changed

+5
-5
lines changed

1 file changed

+5
-5
lines changed

monai/losses/unified_focal_loss.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -169,7 +169,7 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
169169
back_ce = (1.0 - self.delta) * torch.pow(1.0 - y_pred[:, 0], self.gamma) * cross_entropy[:, 0]
170170
# (B, C-1, H, W)
171171
fore_ce = self.delta * cross_entropy[:, 1:]
172-
172+
173173
loss = torch.cat([back_ce.unsqueeze(1), fore_ce], dim=1) # (B, C, H, W)
174174

175175
# Apply reduction
@@ -276,15 +276,15 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
276276
if y_true.shape[1] != 1:
277277
y_true = y_true.unsqueeze(1)
278278
y_true = one_hot(y_true, num_classes=n_pred_ch)
279-
279+
280280
# Ensure y_true has the same shape as y_pred_act
281281
if y_true.shape != y_pred_act.shape:
282282
# This can happen if y_true is (B, H, W) and y_pred is (B, 1, H, W) after sigmoid
283283
if y_true.shape[1] != y_pred_act.shape[1] and y_true.ndim == y_pred_act.ndim - 1:
284284
y_true = y_true.unsqueeze(1) # Add channel dim
285-
285+
286286
if y_true.shape != y_pred_act.shape:
287-
raise ValueError(f"ground truth has different shape ({y_true.shape}) from input ({y_pred_act.shape})
287+
raise ValueError(f"ground truth has different shape ({y_true.shape}) from input ({y_pred_act.shape})
288288
after activations/one-hot")
289289

290290

@@ -293,4 +293,4 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
293293

294294
loss: torch.Tensor = self.lambda_focal * f_loss + (1 - self.lambda_focal) * t_loss
295295

296-
return loss
296+
return loss

0 commit comments

Comments
 (0)