Skip to content

Commit 3c2dbc6

Browse files
committed
Add DownSampleBlock missing tests, Signed-off-by: Santiago Cano-Muniz <[email protected]>,
I, Cano-Muniz, Santiago <[email protected]>, hereby add my Signed-off-by to this commit: 55da640
1 parent 55da640 commit 3c2dbc6

File tree

1 file changed

+134
-1
lines changed

1 file changed

+134
-1
lines changed

tests/networks/blocks/test_downsample_block.py

Lines changed: 134 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,10 @@
1717
from parameterized import parameterized
1818

1919
from monai.networks import eval_mode
20-
from monai.networks.blocks import MaxAvgPool
20+
from monai.networks.blocks import DownSample, MaxAvgPool, SubpixelDownsample, SubpixelUpsample
21+
from monai.utils import optional_import
22+
23+
einops, has_einops = optional_import("einops")
2124

2225
TEST_CASES = [
2326
[{"spatial_dims": 2, "kernel_size": 2}, (7, 4, 64, 48), (7, 8, 32, 24)], # 4-channel 2D, batch 7
@@ -35,6 +38,20 @@
3538
],
3639
]
3740

41+
TEST_CASES_SUBPIXEL = [
42+
[{"spatial_dims": 2, "in_channels": 1, "scale_factor": 2}, (1, 1, 8, 8), (1, 4, 4, 4)],
43+
[{"spatial_dims": 3, "in_channels": 2, "scale_factor": 2}, (1, 2, 8, 8, 8), (1, 16, 4, 4, 4)],
44+
[{"spatial_dims": 1, "in_channels": 3, "scale_factor": 2}, (1, 3, 8), (1, 6, 4)],
45+
]
46+
47+
TEST_CASES_DOWNSAMPLE = [
48+
[{"spatial_dims": 2, "in_channels": 4, "mode": "conv"}, (1, 4, 16, 16), (1, 4, 8, 8)],
49+
[{"spatial_dims": 2, "in_channels": 4, "out_channels": 8, "mode": "convgroup"}, (1, 4, 16, 16), (1, 8, 8, 8)],
50+
[{"spatial_dims": 3, "in_channels": 2, "mode": "maxpool"}, (1, 2, 16, 16, 16), (1, 2, 8, 8, 8)],
51+
[{"spatial_dims": 2, "in_channels": 4, "mode": "avgpool"}, (1, 4, 16, 16), (1, 4, 8, 8)],
52+
[{"spatial_dims": 2, "in_channels": 1, "mode": "pixelunshuffle"}, (1, 1, 16, 16), (1, 4, 8, 8)],
53+
]
54+
3855

3956
class TestMaxAvgPool(unittest.TestCase):
4057

@@ -46,5 +63,121 @@ def test_shape(self, input_param, input_shape, expected_shape):
4663
self.assertEqual(result.shape, expected_shape)
4764

4865

66+
class TestSubpixelDownsample(unittest.TestCase):
67+
68+
@parameterized.expand(TEST_CASES_SUBPIXEL)
69+
def test_shape(self, input_param, input_shape, expected_shape):
70+
downsampler = SubpixelDownsample(**input_param)
71+
with eval_mode(downsampler):
72+
result = downsampler(torch.randn(input_shape))
73+
self.assertEqual(result.shape, expected_shape)
74+
75+
def test_predefined_tensor(self):
76+
test_tensor = torch.arange(4).view(4, 1, 1).repeat(1, 4, 4)
77+
test_tensor = test_tensor.unsqueeze(0)
78+
79+
downsampler = SubpixelDownsample(spatial_dims=2, in_channels=1, scale_factor=2, conv_block=None)
80+
with eval_mode(downsampler):
81+
result = downsampler(test_tensor)
82+
self.assertEqual(result.shape, (1, 16, 2, 2))
83+
self.assertTrue(torch.all(result[0, 0:3] == 0))
84+
self.assertTrue(torch.all(result[0, 4:7] == 1))
85+
self.assertTrue(torch.all(result[0, 8:11] == 2))
86+
self.assertTrue(torch.all(result[0, 12:15] == 3))
87+
88+
def test_reconstruction_2d(self):
89+
input_tensor = torch.randn(1, 1, 4, 4)
90+
down = SubpixelDownsample(spatial_dims=2, in_channels=1, scale_factor=2, conv_block=None)
91+
up = SubpixelUpsample(spatial_dims=2, in_channels=4, scale_factor=2, conv_block=None, apply_pad_pool=False)
92+
with eval_mode(down), eval_mode(up):
93+
downsampled = down(input_tensor)
94+
reconstructed = up(downsampled)
95+
self.assertTrue(torch.allclose(input_tensor, reconstructed, rtol=1e-5))
96+
97+
def test_reconstruction_3d(self):
98+
input_tensor = torch.randn(1, 1, 4, 4, 4)
99+
down = SubpixelDownsample(spatial_dims=3, in_channels=1, scale_factor=2, conv_block=None)
100+
up = SubpixelUpsample(spatial_dims=3, in_channels=4, scale_factor=2, conv_block=None, apply_pad_pool=False)
101+
with eval_mode(down), eval_mode(up):
102+
downsampled = down(input_tensor)
103+
reconstructed = up(downsampled)
104+
self.assertTrue(torch.allclose(input_tensor, reconstructed, rtol=1e-5))
105+
106+
def test_invalid_spatial_size(self):
107+
downsampler = SubpixelDownsample(spatial_dims=2, in_channels=1, scale_factor=2)
108+
with self.assertRaises(ValueError):
109+
downsampler(torch.randn(1, 1, 3, 4))
110+
111+
def test_custom_conv_block(self):
112+
custom_conv = torch.nn.Conv2d(1, 2, kernel_size=3, padding=1)
113+
downsampler = SubpixelDownsample(spatial_dims=2, in_channels=1, scale_factor=2, conv_block=custom_conv)
114+
with eval_mode(downsampler):
115+
result = downsampler(torch.randn(1, 1, 4, 4))
116+
self.assertEqual(result.shape, (1, 8, 2, 2))
117+
118+
119+
class TestDownSample(unittest.TestCase):
120+
@parameterized.expand(TEST_CASES_DOWNSAMPLE)
121+
def test_shape(self, input_param, input_shape, expected_shape):
122+
net = DownSample(**input_param)
123+
with eval_mode(net):
124+
result = net(torch.randn(input_shape))
125+
self.assertEqual(result.shape, expected_shape)
126+
127+
def test_pre_post_conv(self):
128+
net = DownSample(
129+
spatial_dims=2,
130+
in_channels=4,
131+
out_channels=8,
132+
mode="maxpool",
133+
pre_conv="default",
134+
post_conv=torch.nn.Conv2d(8, 16, 1),
135+
)
136+
with eval_mode(net):
137+
result = net(torch.randn(1, 4, 16, 16))
138+
self.assertEqual(result.shape, (1, 16, 8, 8))
139+
140+
def test_pixelunshuffle_equivalence(self):
141+
class DownSampleLocal(torch.nn.Module):
142+
def __init__(self, n_feat: int):
143+
super().__init__()
144+
self.conv = torch.nn.Conv2d(n_feat, n_feat // 2, kernel_size=3, stride=1, padding=1, bias=False)
145+
self.pixelunshuffle = torch.nn.PixelUnshuffle(2)
146+
147+
def forward(self, x: torch.Tensor) -> torch.Tensor:
148+
x = self.conv(x)
149+
return self.pixelunshuffle(x)
150+
151+
n_feat = 2
152+
x = torch.randn(1, n_feat, 64, 64)
153+
154+
fix_weight_conv = torch.nn.Conv2d(n_feat, n_feat // 2, kernel_size=3, stride=1, padding=1, bias=False)
155+
156+
monai_down = DownSample(
157+
spatial_dims=2,
158+
in_channels=n_feat,
159+
out_channels=n_feat // 2,
160+
mode="pixelunshuffle",
161+
pre_conv=fix_weight_conv,
162+
)
163+
164+
local_down = DownSampleLocal(n_feat)
165+
local_down.conv.weight.data = fix_weight_conv.weight.data.clone()
166+
167+
with eval_mode(monai_down), eval_mode(local_down):
168+
out_monai = monai_down(x)
169+
out_local = local_down(x)
170+
171+
self.assertTrue(torch.allclose(out_monai, out_local, rtol=1e-5))
172+
173+
def test_invalid_mode(self):
174+
with self.assertRaises(ValueError):
175+
DownSample(spatial_dims=2, in_channels=4, mode="invalid")
176+
177+
def test_missing_channels(self):
178+
with self.assertRaises(ValueError):
179+
DownSample(spatial_dims=2, mode="conv")
180+
181+
49182
if __name__ == "__main__":
50183
unittest.main()

0 commit comments

Comments
 (0)