Commit 3d7dfc6
Fix perceptualLoss normalize_tensor to prevent NaN errors
1 parent: 83dcd35

2 files changed: +78 additions, -2 deletions

monai/losses/perceptual.py

Lines changed: 2 additions & 2 deletions

@@ -271,8 +271,8 @@ def spatial_average_3d(x: torch.Tensor, keepdim: bool = True) -> torch.Tensor:
     return x.mean([2, 3, 4], keepdim=keepdim)
 
 
-def normalize_tensor(x: torch.Tensor, eps: float = 1e-10) -> torch.Tensor:
-    norm_factor = torch.sqrt(torch.sum(x**2, dim=1, keepdim=True))
+def normalize_tensor(x: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
+    norm_factor = torch.sqrt(torch.sum(x**2, dim=1, keepdim=True) + eps)
     return x / (norm_factor + eps)
 
 
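Why the eps placement matters: for an all-zero feature vector, torch.sum(x**2) is 0, the gradient of sqrt at 0 is infinite, and autograd multiplies that infinity by a zero upstream term, producing NaN. Moving eps inside the sqrt bounds the gradient by 1/(2*sqrt(eps)). A minimal repro sketch, independent of the commit (the _old/_new helper names are illustrative, not MONAI API):

import torch

def normalize_tensor_old(x: torch.Tensor, eps: float = 1e-10) -> torch.Tensor:
    # eps outside the sqrt: forward pass is fine, backward pass is not
    norm_factor = torch.sqrt(torch.sum(x**2, dim=1, keepdim=True))
    return x / (norm_factor + eps)

def normalize_tensor_new(x: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    # eps inside the sqrt keeps d/ds sqrt(s + eps) <= 1/(2*sqrt(eps)) finite
    norm_factor = torch.sqrt(torch.sum(x**2, dim=1, keepdim=True) + eps)
    return x / (norm_factor + eps)

x = torch.zeros(2, 3, 4, 4, requires_grad=True)
normalize_tensor_old(x).sum().backward()
print(torch.isnan(x.grad).any())  # tensor(True): inf gradient of sqrt at 0 yields NaN

x.grad = None
normalize_tensor_new(x).sum().backward()
print(torch.isnan(x.grad).any())  # tensor(False): gradient stays finite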
Lines changed: 76 additions & 0 deletions

@@ -0,0 +1,76 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#     http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+import unittest
+
+import torch
+import torch.optim as optim
+
+from parameterized import parameterized
+
+from monai.losses.perceptual import normalize_tensor
+from monai.utils import set_determinism
+
+
+class TestNormalizeTensorStability(unittest.TestCase):
+    def setUp(self):
+        set_determinism(seed=0)
+        self.addCleanup(set_determinism, None)
+
+    def tearDown(self):
+        set_determinism(None)
+
+    @parameterized.expand(
+        [
+            ["e-3", 1e-3],
+            ["e-6", 1e-6],
+            ["e-9", 1e-9],
+            ["e-12", 1e-12],  # small values
+        ]
+    )
+    def test_normalize_tensor_stability(self, name, scale):
+        """Test that small values don't produce NaNs and are handled gracefully."""
+        # Create a zero tensor that requires gradients
+        x = torch.zeros(2, 3, 10, 10, requires_grad=True)
+
+        optimizer = optim.Adam([x], lr=0.01)
+        x_scaled = x * scale
+        normalized = normalize_tensor(x_scaled)
+
+        # Compute a scalar loss to force a backward pass
+        loss = normalized.sum()
+
+        # this is where the unguarded sqrt produced NaNs before the fix
+        loss.backward()
+
+        # Check for NaNs in gradients
+        self.assertFalse(
+            torch.isnan(x.grad).any(),
+            f"NaN gradients detected with scale {scale:.10e}",
+        )
+
+    def test_normalize_tensor_zero_input(self):
+        """Test that normalize_tensor handles zero inputs gracefully."""
+        # Create a tensor of zeros
+        x = torch.zeros(2, 3, 4, 4, requires_grad=True)
+
+        normalized = normalize_tensor(x)
+        loss = normalized.sum()
+        loss.backward()
+
+        # Check for NaNs in gradients
+        self.assertFalse(torch.isnan(x.grad).any(), "NaN gradients detected with zero input")
+
+
+if __name__ == "__main__":
+    unittest.main()
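For context, normalize_tensor performs channel-wise L2 normalization of feature maps before they are compared, in the style of LPIPS-like perceptual losses. A usage sketch (the feature tensors and shapes below are made up for illustration; only normalize_tensor comes from the diff):

import torch
from monai.losses.perceptual import normalize_tensor

# Hypothetical intermediate activations for prediction and target, shape (N, C, H, W)
feats_pred = torch.randn(4, 64, 32, 32)
feats_target = torch.randn(4, 64, 32, 32)

# Unit-normalize along the channel dimension, then compare; with the fix,
# this stays NaN-free in the backward pass even if a feature map is all zeros.
dist = (normalize_tensor(feats_pred) - normalize_tensor(feats_target)) ** 2
print(dist.mean())  # scalar perceptual-distance term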

0 commit comments
