Skip to content

Commit ccc5459

Browse files
committed
update test_unified_focal_loss.py
Signed-off-by: ytl0623 <[email protected]>
1 parent c27945a commit ccc5459

File tree

1 file changed

+48
-20
lines changed

1 file changed

+48
-20
lines changed

tests/losses/test_unified_focal_loss.py

Lines changed: 48 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -12,38 +12,62 @@
1212
from __future__ import annotations
1313

1414
import unittest
15-
1615
import numpy as np
1716
import torch
1817
from parameterized import parameterized
19-
2018
from monai.losses import AsymmetricUnifiedFocalLoss
2119

20+
LOGIT_HIGH = 5.0
21+
LOGIT_LOW = -5.0
22+
2223
TEST_CASES = [
23-
[ # shape: (2, 1, 2, 2), (2, 1, 2, 2)
24+
# Case 0: Sigmoid + Include BG
25+
[
26+
{"use_softmax": False, "include_background": True},
27+
{
28+
"y_pred": torch.tensor([[[[LOGIT_HIGH]], [[LOGIT_LOW]]]]),
29+
"y_true": torch.tensor([[[[1.0]], [[0.0]]]]),
30+
},
31+
],
32+
# Case 1: Softmax + Ignore BG
33+
[
34+
{"use_softmax": True, "include_background": False},
35+
{
36+
"y_pred": torch.tensor([[[[LOGIT_LOW]], [[LOGIT_HIGH]], [[LOGIT_LOW]]]]),
37+
"y_true": torch.tensor([[[[0.0]], [[1.0]], [[0.0]]]]),
38+
},
39+
],
40+
# Case 2: Softmax + Include BG
41+
[
42+
{"use_softmax": True, "include_background": True},
2443
{
25-
"y_pred": torch.tensor([[[[1.0, 0], [0, 1.0]]], [[[1.0, 0], [0, 1.0]]]]),
26-
"y_true": torch.tensor([[[[1.0, 0], [0, 1.0]]], [[[1.0, 0], [0, 1.0]]]]),
44+
"y_pred": torch.tensor([[[[LOGIT_HIGH]], [[LOGIT_LOW]], [[LOGIT_LOW]]]]),
45+
"y_true": torch.tensor([[[[1.0]], [[0.0]], [[0.0]]]]),
2746
},
28-
0.0,
2947
],
30-
[ # shape: (2, 1, 2, 2), (2, 1, 2, 2)
48+
# Case 3: Sigmoid + Ignore BG
49+
[
50+
{"use_softmax": False, "include_background": False},
3151
{
32-
"y_pred": torch.tensor([[[[1.0, 0], [0, 1.0]]], [[[1.0, 0], [0, 1.0]]]]),
33-
"y_true": torch.tensor([[[[1.0, 0], [0, 1.0]]], [[[1.0, 0], [0, 1.0]]]]),
52+
"y_pred": torch.tensor([[[[LOGIT_HIGH]], [[LOGIT_HIGH]]]]),
53+
"y_true": torch.tensor([[[[0.0]], [[1.0]]]]),
3454
},
35-
0.0,
3655
],
3756
]
3857

39-
4058
class TestAsymmetricUnifiedFocalLoss(unittest.TestCase):
4159

4260
@parameterized.expand(TEST_CASES)
43-
def test_result(self, input_data, expected_val):
44-
loss = AsymmetricUnifiedFocalLoss()
45-
result = loss(**input_data)
46-
np.testing.assert_allclose(result.detach().cpu().numpy(), expected_val, atol=1e-4, rtol=1e-4)
61+
def test_result(self, input_param, input_data):
62+
loss_func = AsymmetricUnifiedFocalLoss(**input_param)
63+
result = loss_func(**input_data)
64+
res_val = result.detach().cpu().numpy()
65+
66+
print(f"Params: {input_param} -> Loss: {res_val}")
67+
68+
69+
self.assertFalse(np.isnan(res_val), "Loss should not be NaN")
70+
self.assertTrue(res_val < 1.0, f"Loss {res_val} is too high (expected < 1.0)")
4771

4872
def test_ill_shape(self):
4973
loss = AsymmetricUnifiedFocalLoss()
@@ -52,15 +76,19 @@ def test_ill_shape(self):
5276

5377
def test_with_cuda(self):
5478
loss = AsymmetricUnifiedFocalLoss()
55-
i = torch.tensor([[[[1.0, 0], [0, 1.0]]], [[[1.0, 0], [0, 1.0]]]])
79+
i = torch.tensor([[[[5.0, -5.0], [-5.0, 5.0]]], [[[5.0, -5.0], [-5.0, 5.0]]]])
5680
j = torch.tensor([[[[1.0, 0], [0, 1.0]]], [[[1.0, 0], [0, 1.0]]]])
81+
5782
if torch.cuda.is_available():
5883
i = i.cuda()
5984
j = j.cuda()
85+
loss = loss.cuda()
86+
6087
output = loss(i, j)
61-
print(output)
62-
np.testing.assert_allclose(output.detach().cpu().numpy(), 0.0, atol=1e-4, rtol=1e-4)
63-
88+
res_val = output.detach().cpu().numpy()
89+
90+
self.assertFalse(np.isnan(res_val), "CUDA Loss should not be NaN")
91+
self.assertTrue(res_val < 1.0, f"CUDA Loss {res_val} is too high")
6492

6593
if __name__ == "__main__":
66-
unittest.main()
94+
unittest.main()

0 commit comments

Comments
 (0)