Skip to content

Commit d2cc08f

Browse files
committed
Updates to test precision to account for float32 conversion, this should be removed when PyTorch is updated.
Signed-off-by: Eric Kerfoot <[email protected]>
1 parent a1f8df9 commit d2cc08f

File tree

3 files changed

+28
-3
lines changed

3 files changed

+28
-3
lines changed

environment-dev-test.yml

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
name: monai
2+
channels:
3+
- pytorch
4+
- defaults
5+
- nvidia
6+
- conda-forge
7+
dependencies:
8+
- numpy>=1.17
9+
- pytorch>=1.8
10+
- torchvision
11+
- pytorch-cuda=11.6
12+
- pip
13+
- pip:
14+
- -r requirements-dev.txt

tests/integration/test_pad_collation.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,10 @@
1111

1212
from __future__ import annotations
1313

14+
import os
1415
import random
1516
import unittest
17+
from contextlib import redirect_stderr
1618
from functools import wraps
1719

1820
import numpy as np
@@ -35,7 +37,7 @@
3537
RandZoomd,
3638
ToTensor,
3739
)
38-
from monai.utils import set_determinism
40+
from monai.utils import first, set_determinism
3941

4042

4143
@wraps(pad_list_data_collate)
@@ -97,8 +99,9 @@ def test_pad_collation(self, t_type, collate_method, transform):
9799
# Default collation should raise an error
98100
loader_fail = DataLoader(dataset, batch_size=10)
99101
with self.assertRaises(RuntimeError):
100-
for _ in loader_fail:
101-
pass
102+
# stifle PyTorch error reporting, we expect failure so don't need to look at it
103+
with open(os.devnull) as f, redirect_stderr(f):
104+
_ = first(loader_fail)
102105

103106
# Padded collation shouldn't
104107
loader = DataLoader(dataset, batch_size=10, collate_fn=collate_method)

tests/lazy_transforms_utils.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from __future__ import annotations
1313

1414
from copy import deepcopy
15+
import sys
1516

1617
from monai.data import MetaTensor, set_track_meta
1718
from monai.transforms import InvertibleTransform, MapTransform, Randomizable
@@ -62,6 +63,13 @@ def test_resampler_lazy(
6263
resampler.set_random_state(seed=seed)
6364
set_track_meta(True)
6465
resampler.lazy = True
66+
67+
# FIXME: this is a fix for https://github.com/Project-MONAI/MONAI/pull/8429, remove when PyTorch has
68+
# fixed the underlying issue
69+
if sys.platform == "win32":
70+
atol=1e-4
71+
rtol=1e-4
72+
6573
pending_output = resampler(**deepcopy(call_param))
6674
if output_idx is not None:
6775
expected_output, pending_output = (expected_output[output_idx], pending_output[output_idx])

0 commit comments

Comments
 (0)