1717from parameterized import parameterized
1818
1919from monai .networks import eval_mode
20- from monai .networks .blocks import MaxAvgPool
20+ from monai .networks .blocks import DownSample , MaxAvgPool , SubpixelDownsample , SubpixelUpsample
21+ from monai .utils import optional_import
22+
23+ einops , has_einops = optional_import ("einops" )
2124
2225TEST_CASES = [
2326 [{"spatial_dims" : 2 , "kernel_size" : 2 }, (7 , 4 , 64 , 48 ), (7 , 8 , 32 , 24 )], # 4-channel 2D, batch 7
3538 ],
3639]
3740
41+ TEST_CASES_SUBPIXEL = [
42+ [{"spatial_dims" : 2 , "in_channels" : 1 , "scale_factor" : 2 }, (1 , 1 , 8 , 8 ), (1 , 4 , 4 , 4 )],
43+ [{"spatial_dims" : 3 , "in_channels" : 2 , "scale_factor" : 2 }, (1 , 2 , 8 , 8 , 8 ), (1 , 16 , 4 , 4 , 4 )],
44+ [{"spatial_dims" : 1 , "in_channels" : 3 , "scale_factor" : 2 }, (1 , 3 , 8 ), (1 , 6 , 4 )],
45+ ]
46+
47+ TEST_CASES_DOWNSAMPLE = [
48+ [{"spatial_dims" : 2 , "in_channels" : 4 , "mode" : "conv" }, (1 , 4 , 16 , 16 ), (1 , 4 , 8 , 8 )],
49+ [{"spatial_dims" : 2 , "in_channels" : 4 , "out_channels" : 8 , "mode" : "convgroup" }, (1 , 4 , 16 , 16 ), (1 , 8 , 8 , 8 )],
50+ [{"spatial_dims" : 3 , "in_channels" : 2 , "mode" : "maxpool" }, (1 , 2 , 16 , 16 , 16 ), (1 , 2 , 8 , 8 , 8 )],
51+ [{"spatial_dims" : 2 , "in_channels" : 4 , "mode" : "avgpool" }, (1 , 4 , 16 , 16 ), (1 , 4 , 8 , 8 )],
52+ [{"spatial_dims" : 2 , "in_channels" : 1 , "mode" : "pixelunshuffle" }, (1 , 1 , 16 , 16 ), (1 , 4 , 8 , 8 )],
53+ ]
54+
3855
3956class TestMaxAvgPool (unittest .TestCase ):
4057
@@ -46,5 +63,121 @@ def test_shape(self, input_param, input_shape, expected_shape):
4663 self .assertEqual (result .shape , expected_shape )
4764
4865
66+ class TestSubpixelDownsample (unittest .TestCase ):
67+
68+ @parameterized .expand (TEST_CASES_SUBPIXEL )
69+ def test_shape (self , input_param , input_shape , expected_shape ):
70+ downsampler = SubpixelDownsample (** input_param )
71+ with eval_mode (downsampler ):
72+ result = downsampler (torch .randn (input_shape ))
73+ self .assertEqual (result .shape , expected_shape )
74+
75+ def test_predefined_tensor (self ):
76+ test_tensor = torch .arange (4 ).view (4 , 1 , 1 ).repeat (1 , 4 , 4 )
77+ test_tensor = test_tensor .unsqueeze (0 )
78+
79+ downsampler = SubpixelDownsample (spatial_dims = 2 , in_channels = 1 , scale_factor = 2 , conv_block = None )
80+ with eval_mode (downsampler ):
81+ result = downsampler (test_tensor )
82+ self .assertEqual (result .shape , (1 , 16 , 2 , 2 ))
83+ self .assertTrue (torch .all (result [0 , 0 :3 ] == 0 ))
84+ self .assertTrue (torch .all (result [0 , 4 :7 ] == 1 ))
85+ self .assertTrue (torch .all (result [0 , 8 :11 ] == 2 ))
86+ self .assertTrue (torch .all (result [0 , 12 :15 ] == 3 ))
87+
88+ def test_reconstruction_2d (self ):
89+ input_tensor = torch .randn (1 , 1 , 4 , 4 )
90+ down = SubpixelDownsample (spatial_dims = 2 , in_channels = 1 , scale_factor = 2 , conv_block = None )
91+ up = SubpixelUpsample (spatial_dims = 2 , in_channels = 4 , scale_factor = 2 , conv_block = None , apply_pad_pool = False )
92+ with eval_mode (down ), eval_mode (up ):
93+ downsampled = down (input_tensor )
94+ reconstructed = up (downsampled )
95+ self .assertTrue (torch .allclose (input_tensor , reconstructed , rtol = 1e-5 ))
96+
97+ def test_reconstruction_3d (self ):
98+ input_tensor = torch .randn (1 , 1 , 4 , 4 , 4 )
99+ down = SubpixelDownsample (spatial_dims = 3 , in_channels = 1 , scale_factor = 2 , conv_block = None )
100+ up = SubpixelUpsample (spatial_dims = 3 , in_channels = 4 , scale_factor = 2 , conv_block = None , apply_pad_pool = False )
101+ with eval_mode (down ), eval_mode (up ):
102+ downsampled = down (input_tensor )
103+ reconstructed = up (downsampled )
104+ self .assertTrue (torch .allclose (input_tensor , reconstructed , rtol = 1e-5 ))
105+
106+ def test_invalid_spatial_size (self ):
107+ downsampler = SubpixelDownsample (spatial_dims = 2 , in_channels = 1 , scale_factor = 2 )
108+ with self .assertRaises (ValueError ):
109+ downsampler (torch .randn (1 , 1 , 3 , 4 ))
110+
111+ def test_custom_conv_block (self ):
112+ custom_conv = torch .nn .Conv2d (1 , 2 , kernel_size = 3 , padding = 1 )
113+ downsampler = SubpixelDownsample (spatial_dims = 2 , in_channels = 1 , scale_factor = 2 , conv_block = custom_conv )
114+ with eval_mode (downsampler ):
115+ result = downsampler (torch .randn (1 , 1 , 4 , 4 ))
116+ self .assertEqual (result .shape , (1 , 8 , 2 , 2 ))
117+
118+
119+ class TestDownSample (unittest .TestCase ):
120+ @parameterized .expand (TEST_CASES_DOWNSAMPLE )
121+ def test_shape (self , input_param , input_shape , expected_shape ):
122+ net = DownSample (** input_param )
123+ with eval_mode (net ):
124+ result = net (torch .randn (input_shape ))
125+ self .assertEqual (result .shape , expected_shape )
126+
127+ def test_pre_post_conv (self ):
128+ net = DownSample (
129+ spatial_dims = 2 ,
130+ in_channels = 4 ,
131+ out_channels = 8 ,
132+ mode = "maxpool" ,
133+ pre_conv = "default" ,
134+ post_conv = torch .nn .Conv2d (8 , 16 , 1 ),
135+ )
136+ with eval_mode (net ):
137+ result = net (torch .randn (1 , 4 , 16 , 16 ))
138+ self .assertEqual (result .shape , (1 , 16 , 8 , 8 ))
139+
140+ def test_pixelunshuffle_equivalence (self ):
141+ class DownSampleLocal (torch .nn .Module ):
142+ def __init__ (self , n_feat : int ):
143+ super ().__init__ ()
144+ self .conv = torch .nn .Conv2d (n_feat , n_feat // 2 , kernel_size = 3 , stride = 1 , padding = 1 , bias = False )
145+ self .pixelunshuffle = torch .nn .PixelUnshuffle (2 )
146+
147+ def forward (self , x : torch .Tensor ) -> torch .Tensor :
148+ x = self .conv (x )
149+ return self .pixelunshuffle (x )
150+
151+ n_feat = 2
152+ x = torch .randn (1 , n_feat , 64 , 64 )
153+
154+ fix_weight_conv = torch .nn .Conv2d (n_feat , n_feat // 2 , kernel_size = 3 , stride = 1 , padding = 1 , bias = False )
155+
156+ monai_down = DownSample (
157+ spatial_dims = 2 ,
158+ in_channels = n_feat ,
159+ out_channels = n_feat // 2 ,
160+ mode = "pixelunshuffle" ,
161+ pre_conv = fix_weight_conv ,
162+ )
163+
164+ local_down = DownSampleLocal (n_feat )
165+ local_down .conv .weight .data = fix_weight_conv .weight .data .clone ()
166+
167+ with eval_mode (monai_down ), eval_mode (local_down ):
168+ out_monai = monai_down (x )
169+ out_local = local_down (x )
170+
171+ self .assertTrue (torch .allclose (out_monai , out_local , rtol = 1e-5 ))
172+
173+ def test_invalid_mode (self ):
174+ with self .assertRaises (ValueError ):
175+ DownSample (spatial_dims = 2 , in_channels = 4 , mode = "invalid" )
176+
177+ def test_missing_channels (self ):
178+ with self .assertRaises (ValueError ):
179+ DownSample (spatial_dims = 2 , mode = "conv" )
180+
181+
49182if __name__ == "__main__" :
50183 unittest .main ()
0 commit comments