1919
2020from monai .losses import AsymmetricUnifiedFocalLoss
2121
22+ # Helper to create high confidence logits (approx 10 -> sigmoid close to 1, -10 -> sigmoid close to 0)
23+ logit_pos = 10.0
24+ logit_neg = - 10.0
25+
2226TEST_CASES = [
23- [ # shape: (2, 1, 2, 2), (2, 1, 2, 2)
27+ # Case 0: Binary Segmentation (Sigmoid), Perfect Prediction
28+ # Shape: (B=2, C=1, H=2, W=2)
29+ [
30+ {
31+ "use_softmax" : False ,
32+ "include_background" : True ,
33+ },
2434 {
25- "y_pred" : torch .tensor ([[[[1.0 , 0 ], [0 , 1.0 ]]], [[[1.0 , 0 ], [0 , 1.0 ]]]]),
26- "y_true" : torch .tensor ([[[[1.0 , 0 ], [0 , 1.0 ]]], [[[1.0 , 0 ], [0 , 1.0 ]]]]),
35+ # Logits: High value where ground truth is 1, Low value where ground truth is 0
36+ "y_pred" : torch .tensor (
37+ [[[[logit_pos , logit_neg ], [logit_neg , logit_pos ]]], [[[logit_pos , logit_neg ], [logit_neg , logit_pos ]]]]
38+ ),
39+ "y_true" : torch .tensor ([[[[1.0 , 0.0 ], [0.0 , 1.0 ]]], [[[1.0 , 0.0 ], [0.0 , 1.0 ]]]]),
2740 },
2841 0.0 ,
2942 ],
30- [ # shape: (2, 1, 2, 2), (2, 1, 2, 2)
43+ # Case 1: Multi-class Segmentation (Softmax), Perfect Prediction
44+ # Shape: (B=1, C=3, H=1, W=2) -> 3 classes (0: Background, 1: Class A, 2: Class B)
45+ [
46+ {
47+ "use_softmax" : True ,
48+ "include_background" : True ,
49+ },
3150 {
32- "y_pred" : torch .tensor ([[[[1.0 , 0 ], [0 , 1.0 ]]], [[[1.0 , 0 ], [0 , 1.0 ]]]]),
33- "y_true" : torch .tensor ([[[[1.0 , 0 ], [0 , 1.0 ]]], [[[1.0 , 0 ], [0 , 1.0 ]]]]),
51+ # Logits:
52+ # Pixel 1: Class 0 is target -> [10, -10, -10]
53+ # Pixel 2: Class 2 is target -> [-10, -10, 10]
54+ "y_pred" : torch .tensor (
55+ [[[[logit_pos , logit_neg ], [logit_neg , logit_neg ], [logit_neg , logit_pos ]]]] # Ch 0 # Ch 1 # Ch 2
56+ ),
57+ "y_true" : torch .tensor ([[[[1.0 , 0.0 ], [0.0 , 0.0 ], [0.0 , 1.0 ]]]]), # Ch 0 (Background) # Ch 1 # Ch 2
3458 },
3559 0.0 ,
3660 ],
4064class TestAsymmetricUnifiedFocalLoss (unittest .TestCase ):
4165
4266 @parameterized .expand (TEST_CASES )
43- def test_result (self , input_data , expected_val ):
44- loss = AsymmetricUnifiedFocalLoss ()
67+ def test_result (self , input_param , input_data , expected_val ):
68+ loss = AsymmetricUnifiedFocalLoss (** input_param )
4569 result = loss (** input_data )
4670 np .testing .assert_allclose (result .detach ().cpu ().numpy (), expected_val , atol = 1e-4 , rtol = 1e-4 )
4771
@@ -51,12 +75,15 @@ def test_ill_shape(self):
5175 loss (torch .ones ((2 , 2 , 2 )), torch .ones ((2 , 2 , 2 , 2 )))
5276
5377 def test_with_cuda (self ):
54- loss = AsymmetricUnifiedFocalLoss ()
55- i = torch .tensor ([[[[1.0 , 0 ], [0 , 1.0 ]]], [[[1.0 , 0 ], [0 , 1.0 ]]]])
56- j = torch .tensor ([[[[1.0 , 0 ], [0 , 1.0 ]]], [[[1.0 , 0 ], [0 , 1.0 ]]]])
78+ loss = AsymmetricUnifiedFocalLoss (use_softmax = False )
79+ i = torch .tensor ([[[[logit_pos , logit_neg ], [logit_neg , logit_pos ]]]])
80+ j = torch .tensor ([[[[1.0 , 0.0 ], [0.0 , 1.0 ]]]])
81+
5782 if torch .cuda .is_available ():
5883 i = i .cuda ()
5984 j = j .cuda ()
85+ loss = loss .cuda ()
86+
6087 output = loss (i , j )
6188 print (output )
6289 np .testing .assert_allclose (output .detach ().cpu ().numpy (), 0.0 , atol = 1e-4 , rtol = 1e-4 )
0 commit comments