Skip to content

Commit 1d196dc

Browse files
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent ccc5459 commit 1d196dc

File tree

2 files changed

+7
-9
lines changed

2 files changed

+7
-9
lines changed

monai/losses/unified_focal_loss.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,8 @@
1212
from __future__ import annotations
1313

1414
import warnings
15-
from collections.abc import Sequence
1615

1716
import torch
18-
import torch.nn.functional as F
1917
from torch.nn.modules.loss import _Loss
2018

2119
from monai.losses import FocalLoss

tests/losses/test_unified_focal_loss.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -33,15 +33,15 @@
3333
[
3434
{"use_softmax": True, "include_background": False},
3535
{
36-
"y_pred": torch.tensor([[[[LOGIT_LOW]], [[LOGIT_HIGH]], [[LOGIT_LOW]]]]),
36+
"y_pred": torch.tensor([[[[LOGIT_LOW]], [[LOGIT_HIGH]], [[LOGIT_LOW]]]]),
3737
"y_true": torch.tensor([[[[0.0]], [[1.0]], [[0.0]]]]),
3838
},
3939
],
4040
# Case 2: Softmax + Include BG
4141
[
4242
{"use_softmax": True, "include_background": True},
4343
{
44-
"y_pred": torch.tensor([[[[LOGIT_HIGH]], [[LOGIT_LOW]], [[LOGIT_LOW]]]]),
44+
"y_pred": torch.tensor([[[[LOGIT_HIGH]], [[LOGIT_LOW]], [[LOGIT_LOW]]]]),
4545
"y_true": torch.tensor([[[[1.0]], [[0.0]], [[0.0]]]]),
4646
},
4747
],
@@ -62,7 +62,7 @@ def test_result(self, input_param, input_data):
6262
loss_func = AsymmetricUnifiedFocalLoss(**input_param)
6363
result = loss_func(**input_data)
6464
res_val = result.detach().cpu().numpy()
65-
65+
6666
print(f"Params: {input_param} -> Loss: {res_val}")
6767

6868

@@ -78,17 +78,17 @@ def test_with_cuda(self):
7878
loss = AsymmetricUnifiedFocalLoss()
7979
i = torch.tensor([[[[5.0, -5.0], [-5.0, 5.0]]], [[[5.0, -5.0], [-5.0, 5.0]]]])
8080
j = torch.tensor([[[[1.0, 0], [0, 1.0]]], [[[1.0, 0], [0, 1.0]]]])
81-
81+
8282
if torch.cuda.is_available():
8383
i = i.cuda()
8484
j = j.cuda()
8585
loss = loss.cuda()
86-
86+
8787
output = loss(i, j)
8888
res_val = output.detach().cpu().numpy()
89-
89+
9090
self.assertFalse(np.isnan(res_val), "CUDA Loss should not be NaN")
9191
self.assertTrue(res_val < 1.0, f"CUDA Loss {res_val} is too high")
9292

9393
if __name__ == "__main__":
94-
unittest.main()
94+
unittest.main()

0 commit comments

Comments (0)