@@ -15,7 +15,6 @@
 
 import torch
 import torch.optim as optim
-
 from parameterized import parameterized
 
 from monai.losses.perceptual import normalize_tensor
@@ -30,44 +29,34 @@ def setUp(self):
     def tearDown(self):
         set_determinism(None)
 
-    @parameterized.expand(
-        [
-            ["e-3", 1e-3],
-            ["e-6", 1e-6],
-            ["e-9", 1e-9],
-            ["e-12", 1e-12],  # Small values
-        ]
-    )
+    @parameterized.expand([["e-3", 1e-3], ["e-6", 1e-6], ["e-9", 1e-9], ["e-12", 1e-12]])  # Small values
     def test_normalize_tensor_stability(self, name, scale):
         """Test that small values don't produce NaNs and are handled gracefully."""
-        # Create tensor
+        # Create tensor
         x = torch.zeros(2, 3, 10, 10, requires_grad=True)
-
+
         optimizer = optim.Adam([x], lr=0.01)
         x_scaled = x * scale
         normalized = normalize_tensor(x_scaled)
-
+
         # Compute loss to force a backward pass
         loss = normalized.sum()
-
+
         # This is where it failed before
         loss.backward()
 
         # Check for NaNs in gradients
-        self.assertFalse(
-            torch.isnan(x.grad).any(),
-            f"NaN gradients detected with scale {scale:.10e}"
-        )
-
+        self.assertFalse(torch.isnan(x.grad).any(), f"NaN gradients detected with scale {scale:.10e}")
+
     def test_normalize_tensor_zero_input(self):
         """Test that normalize_tensor handles zero inputs gracefully."""
-        # Create tensor with zeros
+        # Create tensor with zeros
         x = torch.zeros(2, 3, 4, 4, requires_grad=True)
-
+
         normalized = normalize_tensor(x)
         loss = normalized.sum()
         loss.backward()
-
+
         # Check for NaNs in gradients
         self.assertFalse(torch.isnan(x.grad).any(), "NaN gradients detected with zero input")
 
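For context on what these tests exercise: normalize_tensor performs a channel-wise L2 normalization, and NaN gradients on zero or tiny inputs typically come from differentiating sqrt at zero. Below is a minimal sketch of an epsilon-stabilized form, not MONAI's actual implementation; the name normalize_tensor_sketch, the eps value, and the channel dimension are illustrative assumptions.

import torch

def normalize_tensor_sketch(x: torch.Tensor, eps: float = 1e-10) -> torch.Tensor:
    # Hypothetical epsilon-stabilized L2 normalization (not the MONAI source).
    # Keeping eps inside the sqrt keeps the backward pass finite: with
    # sqrt(norm_sq) + eps instead, d(sqrt)/d(norm_sq) blows up as norm_sq -> 0,
    # which yields NaN gradients for the zero and near-zero inputs tested above.
    norm_sq = torch.sum(x**2, dim=1, keepdim=True)
    return x / torch.sqrt(norm_sq + eps)

# Mirrors test_normalize_tensor_zero_input: gradients stay finite at x == 0.
x = torch.zeros(2, 3, 4, 4, requires_grad=True)
normalize_tensor_sketch(x).sum().backward()
assert not torch.isnan(x.grad).any()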