from __future__ import annotations

import unittest

import numpy as np
import torch
from parameterized import parameterized

from monai.losses import AsymmetricUnifiedFocalLoss

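# Logit magnitudes large enough to saturate the activation: sigmoid(5.0) ~= 0.993
# and sigmoid(-5.0) ~= 0.007, so these behave as near-certain predictions.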
LOGIT_HIGH = 5.0
LOGIT_LOW = -5.0

TEST_CASES = [
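    # Each y_pred below holds raw logits and each y_true one-hot labels,
    # both with shape (batch=1, channels, 1, 1).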
    # Case 0: Sigmoid + Include BG
    [
        {"use_softmax": False, "include_background": True},
        {
            "y_pred": torch.tensor([[[[LOGIT_HIGH]], [[LOGIT_LOW]]]]),
            "y_true": torch.tensor([[[[1.0]], [[0.0]]]]),
        },
    ],
    # Case 1: Softmax + Ignore BG
    [
        {"use_softmax": True, "include_background": False},
        {
            "y_pred": torch.tensor([[[[LOGIT_LOW]], [[LOGIT_HIGH]], [[LOGIT_LOW]]]]),
            "y_true": torch.tensor([[[[0.0]], [[1.0]], [[0.0]]]]),
        },
    ],
    # Case 2: Softmax + Include BG
    [
        {"use_softmax": True, "include_background": True},
        {
            "y_pred": torch.tensor([[[[LOGIT_HIGH]], [[LOGIT_LOW]], [[LOGIT_LOW]]]]),
            "y_true": torch.tensor([[[[1.0]], [[0.0]], [[0.0]]]]),
        },
    ],
    # Case 3: Sigmoid + Ignore BG
    [
        {"use_softmax": False, "include_background": False},
        {
            "y_pred": torch.tensor([[[[LOGIT_HIGH]], [[LOGIT_HIGH]]]]),
            "y_true": torch.tensor([[[[0.0]], [[1.0]]]]),
        },
    ],
]

class TestAsymmetricUnifiedFocalLoss(unittest.TestCase):

    @parameterized.expand(TEST_CASES)
    def test_result(self, input_param, input_data):
        loss_func = AsymmetricUnifiedFocalLoss(**input_param)
        result = loss_func(**input_data)
        res_val = result.detach().cpu().numpy()

        print(f"Params: {input_param} -> Loss: {res_val}")

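        # With near-saturated, correct logits the loss should sit close to zero;
        # asserting < 1.0 keeps the check loose while still catching NaNs and
        # normalization regressions.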
        self.assertFalse(np.isnan(res_val), "Loss should not be NaN")
        self.assertLess(res_val, 1.0, f"Loss {res_val} is too high (expected < 1.0)")

    def test_ill_shape(self):
        loss = AsymmetricUnifiedFocalLoss()
        ...

    def test_with_cuda(self):
        loss = AsymmetricUnifiedFocalLoss()
        i = torch.tensor([[[[5.0, -5.0], [-5.0, 5.0]]], [[[5.0, -5.0], [-5.0, 5.0]]]])
        j = torch.tensor([[[[1.0, 0.0], [0.0, 1.0]]], [[[1.0, 0.0], [0.0, 1.0]]]])

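        # Move the inputs and the loss module to the GPU when one is available;
        # otherwise the same checks run on CPU.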
        if torch.cuda.is_available():
            i = i.cuda()
            j = j.cuda()
            loss = loss.cuda()

        output = loss(i, j)
        res_val = output.detach().cpu().numpy()

        self.assertFalse(np.isnan(res_val), "CUDA loss should not be NaN")
        self.assertLess(res_val, 1.0, f"CUDA loss {res_val} is too high")

if __name__ == "__main__":
    unittest.main()