1818
1919from monai .networks import eval_mode
2020from monai .networks .blocks import DownSample , MaxAvgPool , SubpixelDownsample , SubpixelUpsample
21+ from monai .utils import optional_import
22+
23+ einops , has_einops = optional_import ("einops" )
2124
2225TEST_CASES = [
2326 [{"spatial_dims" : 2 , "kernel_size" : 2 }, (7 , 4 , 64 , 48 ), (7 , 8 , 32 , 24 )], # 4-channel 2D, batch 7
@@ -82,7 +85,7 @@ def test_predefined_tensor(self):
8285 self .assertTrue (torch .all (result [0 , 8 :11 ] == 2 ))
8386 self .assertTrue (torch .all (result [0 , 12 :15 ] == 3 ))
8487
85- def test_reconstruction_2D (self ):
88+ def test_reconstruction_2d (self ):
8689 input_tensor = torch .randn (1 , 1 , 4 , 4 )
8790 down = SubpixelDownsample (spatial_dims = 2 , in_channels = 1 , scale_factor = 2 , conv_block = None )
8891 up = SubpixelUpsample (spatial_dims = 2 , in_channels = 4 , scale_factor = 2 , conv_block = None , apply_pad_pool = False )
@@ -91,7 +94,7 @@ def test_reconstruction_2D(self):
9194 reconstructed = up (downsampled )
9295 self .assertTrue (torch .allclose (input_tensor , reconstructed , rtol = 1e-5 ))
9396
94- def test_reconstruction_3D (self ):
97+ def test_reconstruction_3d (self ):
9598 input_tensor = torch .randn (1 , 1 , 4 , 4 , 4 )
9699 down = SubpixelDownsample (spatial_dims = 3 , in_channels = 1 , scale_factor = 2 , conv_block = None )
97100 up = SubpixelUpsample (spatial_dims = 3 , in_channels = 4 , scale_factor = 2 , conv_block = None , apply_pad_pool = False )
@@ -135,7 +138,7 @@ def test_pre_post_conv(self):
135138 self .assertEqual (result .shape , (1 , 16 , 8 , 8 ))
136139
137140 def test_pixelunshuffle_equivalence (self ):
138- class DownSample_local (torch .nn .Module ):
141+ class DownSampleLocal (torch .nn .Module ):
139142 def __init__ (self , n_feat : int ):
140143 super ().__init__ ()
141144 self .conv = torch .nn .Conv2d (n_feat , n_feat // 2 , kernel_size = 3 , stride = 1 , padding = 1 , bias = False )
@@ -158,7 +161,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
158161 pre_conv = fix_weight_conv ,
159162 )
160163
161- local_down = DownSample_local (n_feat )
164+ local_down = DownSampleLocal (n_feat )
162165 local_down .conv .weight .data = fix_weight_conv .weight .data .clone ()
163166
164167 with eval_mode (monai_down ), eval_mode (local_down ):
0 commit comments