Skip to content

Commit 41dccad

Browse files
committed
update test_unified_focal_loss.py
Signed-off-by: ytl0623 <[email protected]>
1 parent 1fba9d3 commit 41dccad

File tree

2 files changed

+41
-13
lines changed

2 files changed

+41
-13
lines changed

monai/losses/unified_focal_loss.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,8 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
111111
else:
112112
# Foreground classes: apply focal modulation
113113
# Original logic: (1 - dice) * (1 - dice)^(-gamma) -> (1-dice)^(1-gamma)
114-
loss_list.append((1 - dice_class[:, i]) * torch.pow(1 - dice_class[:, i], -self.gamma))
114+
back_dice = torch.clamp(1 - dice_class[:, i], min=self.epsilon)
115+
loss_list.append(back_dice * torch.pow(back_dice, -self.gamma))
115116

116117
loss = torch.stack(loss_list, dim=-1)
117118

@@ -154,7 +155,7 @@ def __init__(
154155
"""
155156
super().__init__(reduction=LossReduction(reduction).value)
156157
self.weight = weight
157-
self.use_softmax = use_softmax # 儲存參數
158+
self.use_softmax = use_softmax
158159

159160
self.focal_loss = FocalLoss(
160161
include_background=include_background,

tests/losses/test_unified_focal_loss.py

Lines changed: 38 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -19,18 +19,42 @@
1919

2020
from monai.losses import AsymmetricUnifiedFocalLoss
2121

22+
# Helper to create high confidence logits (approx 10 -> sigmoid close to 1, -10 -> sigmoid close to 0)
23+
logit_pos = 10.0
24+
logit_neg = -10.0
25+
2226
TEST_CASES = [
23-
[ # shape: (2, 1, 2, 2), (2, 1, 2, 2)
27+
# Case 0: Binary Segmentation (Sigmoid), Perfect Prediction
28+
# Shape: (B=2, C=1, H=2, W=2)
29+
[
30+
{
31+
"use_softmax": False,
32+
"include_background": True,
33+
},
2434
{
25-
"y_pred": torch.tensor([[[[1.0, 0], [0, 1.0]]], [[[1.0, 0], [0, 1.0]]]]),
26-
"y_true": torch.tensor([[[[1.0, 0], [0, 1.0]]], [[[1.0, 0], [0, 1.0]]]]),
35+
# Logits: High value where ground truth is 1, Low value where ground truth is 0
36+
"y_pred": torch.tensor(
37+
[[[[logit_pos, logit_neg], [logit_neg, logit_pos]]], [[[logit_pos, logit_neg], [logit_neg, logit_pos]]]]
38+
),
39+
"y_true": torch.tensor([[[[1.0, 0.0], [0.0, 1.0]]], [[[1.0, 0.0], [0.0, 1.0]]]]),
2740
},
2841
0.0,
2942
],
30-
[ # shape: (2, 1, 2, 2), (2, 1, 2, 2)
43+
# Case 1: Multi-class Segmentation (Softmax), Perfect Prediction
44+
# Shape: (B=1, C=3, H=1, W=2) -> 3 classes (0: Background, 1: Class A, 2: Class B)
45+
[
46+
{
47+
"use_softmax": True,
48+
"include_background": True,
49+
},
3150
{
32-
"y_pred": torch.tensor([[[[1.0, 0], [0, 1.0]]], [[[1.0, 0], [0, 1.0]]]]),
33-
"y_true": torch.tensor([[[[1.0, 0], [0, 1.0]]], [[[1.0, 0], [0, 1.0]]]]),
51+
# Logits:
52+
# Pixel 1: Class 0 is target -> [10, -10, -10]
53+
# Pixel 2: Class 2 is target -> [-10, -10, 10]
54+
"y_pred": torch.tensor(
55+
[[[[logit_pos, logit_neg], [logit_neg, logit_neg], [logit_neg, logit_pos]]]] # Ch 0 # Ch 1 # Ch 2
56+
),
57+
"y_true": torch.tensor([[[[1.0, 0.0], [0.0, 0.0], [0.0, 1.0]]]]), # Ch 0 (Background) # Ch 1 # Ch 2
3458
},
3559
0.0,
3660
],
@@ -40,8 +64,8 @@
4064
class TestAsymmetricUnifiedFocalLoss(unittest.TestCase):
4165

4266
@parameterized.expand(TEST_CASES)
43-
def test_result(self, input_data, expected_val):
44-
loss = AsymmetricUnifiedFocalLoss()
67+
def test_result(self, input_param, input_data, expected_val):
68+
loss = AsymmetricUnifiedFocalLoss(**input_param)
4569
result = loss(**input_data)
4670
np.testing.assert_allclose(result.detach().cpu().numpy(), expected_val, atol=1e-4, rtol=1e-4)
4771

@@ -51,12 +75,15 @@ def test_ill_shape(self):
5175
loss(torch.ones((2, 2, 2)), torch.ones((2, 2, 2, 2)))
5276

5377
def test_with_cuda(self):
54-
loss = AsymmetricUnifiedFocalLoss()
55-
i = torch.tensor([[[[1.0, 0], [0, 1.0]]], [[[1.0, 0], [0, 1.0]]]])
56-
j = torch.tensor([[[[1.0, 0], [0, 1.0]]], [[[1.0, 0], [0, 1.0]]]])
78+
loss = AsymmetricUnifiedFocalLoss(use_softmax=False)
79+
i = torch.tensor([[[[logit_pos, logit_neg], [logit_neg, logit_pos]]]])
80+
j = torch.tensor([[[[1.0, 0.0], [0.0, 1.0]]]])
81+
5782
if torch.cuda.is_available():
5883
i = i.cuda()
5984
j = j.cuda()
85+
loss = loss.cuda()
86+
6087
output = loss(i, j)
6188
print(output)
6289
np.testing.assert_allclose(output.detach().cpu().numpy(), 0.0, atol=1e-4, rtol=1e-4)

0 commit comments

Comments
 (0)