1919from monai .apps .generation .maisi .networks .autoencoderkl_maisi import AutoencoderKlMaisi
2020from monai .networks import eval_mode
2121from monai .utils import optional_import
22- from tests .test_utils import SkipIfBeforePyTorchVersion
2322
2423tqdm , has_tqdm = optional_import ("tqdm" , name = "tqdm" )
2524_ , has_einops = optional_import ("einops" )
@@ -87,7 +86,6 @@ def test_shape(self, input_param, input_shape, expected_shape, expected_latent_s
8786 self .assertEqual (result [2 ].shape , expected_latent_shape )
8887
8988 @parameterized .expand (CASES )
90- @SkipIfBeforePyTorchVersion ((1 , 11 ))
9189 def test_shape_with_convtranspose_and_checkpointing (
9290 self , input_param , input_shape , expected_shape , expected_latent_shape
9391 ):
@@ -152,7 +150,6 @@ def test_shape_reconstruction(self):
152150 result = net .reconstruct (torch .randn (input_shape ).to (device ))
153151 self .assertEqual (result .shape , expected_shape )
154152
155- @SkipIfBeforePyTorchVersion ((1 , 11 ))
156153 def test_shape_reconstruction_with_convtranspose_and_checkpointing (self ):
157154 input_param , input_shape , expected_shape , _ = CASES [0 ]
158155 input_param = input_param .copy ()
@@ -170,7 +167,6 @@ def test_shape_encode(self):
170167 self .assertEqual (result [0 ].shape , expected_latent_shape )
171168 self .assertEqual (result [1 ].shape , expected_latent_shape )
172169
173- @SkipIfBeforePyTorchVersion ((1 , 11 ))
174170 def test_shape_encode_with_convtranspose_and_checkpointing (self ):
175171 input_param , input_shape , _ , expected_latent_shape = CASES [0 ]
176172 input_param = input_param .copy ()
@@ -190,7 +186,6 @@ def test_shape_sampling(self):
190186 )
191187 self .assertEqual (result .shape , expected_latent_shape )
192188
193- @SkipIfBeforePyTorchVersion ((1 , 11 ))
194189 def test_shape_sampling_convtranspose_and_checkpointing (self ):
195190 input_param , _ , _ , expected_latent_shape = CASES [0 ]
196191 input_param = input_param .copy ()
@@ -209,7 +204,6 @@ def test_shape_decode(self):
209204 result = net .decode (torch .randn (latent_shape ).to (device ))
210205 self .assertEqual (result .shape , expected_input_shape )
211206
212- @SkipIfBeforePyTorchVersion ((1 , 11 ))
213207 def test_shape_decode_convtranspose_and_checkpointing (self ):
214208 input_param , expected_input_shape , _ , latent_shape = CASES [0 ]
215209 input_param = input_param .copy ()
0 commit comments