Skip to content

Commit c45ee48

Browse files
author
Fabio Ferreira
committed
fix: tighten tolerance for numerical equivalence
Signed-off-by: Fabio Ferreira <[email protected]>
1 parent 814fa80 commit c45ee48

File tree

1 file changed

+5
-3
lines changed

1 file changed

+5
-3
lines changed

tests/networks/nets/test_checkpointunet.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -151,9 +151,11 @@ def test_checkpointing_equivalence_eval(self):
151151
# Check shape equality
152152
self.assertEqual(y_ckpt.shape, y_plain.shape)
153153

154-
# Check numerical similarity
155-
diff = torch.mean(torch.abs(y_ckpt - y_plain)).item()
156-
self.assertLess(diff, 1e-3, f"Eval-mode outputs differ more than expected (mean abs diff={diff:.6f})")
154+
# Check numerical equivalence
155+
self.assertTrue(
156+
torch.allclose(y_ckpt, y_plain, atol=1e-6, rtol=1e-5),
157+
f"Eval-mode outputs differ: max abs diff={torch.max(torch.abs(y_ckpt - y_plain)).item():.2e}",
158+
)
157159

158160
def test_checkpointing_activates_training(self):
159161
"""Verify checkpointing recomputes activations during training."""

0 commit comments

Comments
 (0)