Skip to content

Commit f5a2f7e

Browse files
pre-commit-ci[bot]ytl0623
authored andcommitted
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci Signed-off-by: ytl0623 <[email protected]>
1 parent ad83444 commit f5a2f7e

File tree

1 file changed

+10
-10
lines changed

1 file changed

+10
-10
lines changed

monai/losses/unified_focal_loss.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414
import warnings
1515

1616
import torch
17-
import torch.nn.functional as F
1817
from torch.nn.modules.loss import _Loss
1918

2019
from monai.networks import one_hot
@@ -169,7 +168,7 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
169168
back_ce = (1.0 - self.delta) * torch.pow(1.0 - y_pred[:, 0], self.gamma) * cross_entropy[:, 0]
170169
# (B, C-1, H, W)
171170
fore_ce = self.delta * cross_entropy[:, 1:]
172-
171+
173172
loss = torch.cat([back_ce.unsqueeze(1), fore_ce], dim=1) # (B, C, H, W)
174173

175174
# Apply reduction
@@ -276,21 +275,22 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
276275
if y_true.shape[1] != 1:
277276
y_true = y_true.unsqueeze(1)
278277
y_true = one_hot(y_true, num_classes=n_pred_ch)
279-
278+
280279
# Ensure y_true has the same shape as y_pred_act
281280
if y_true.shape != y_pred_act.shape:
282-
# This can happen if y_true is (B, H, W) and y_pred is (B, 1, H, W) after sigmoid
281+
# This can happen if y_true is (B, H, W) and y_pred is (B, 1, H, W) after sigmoid
283282
if y_true.shape[1] != y_pred_act.shape[1] and y_true.ndim == y_pred_act.ndim - 1:
284-
y_true = y_true.unsqueeze(1) # Add channel dim
285-
286-
if y_true.shape != y_pred_act.shape:
287-
raise ValueError(f"ground truth has different shape ({y_true.shape}) from input ({y_pred_act.shape}) " \
288-
f"after activations/one-hot")
283+
y_true = y_true.unsqueeze(1) # Add channel dim
289284

285+
if y_true.shape != y_pred_act.shape:
286+
raise ValueError(
287+
f"ground truth has different shape ({y_true.shape}) from input ({y_pred_act.shape}) "
288+
f"after activations/one-hot"
289+
)
290290

291291
f_loss = self.asy_focal_loss(y_pred_act, y_true)
292292
t_loss = self.asy_focal_tversky_loss(y_pred_act, y_true)
293293

294294
loss: torch.Tensor = self.lambda_focal * f_loss + (1 - self.lambda_focal) * t_loss
295295

296-
return loss
296+
return loss

0 commit comments

Comments
 (0)