Skip to content

Commit c7b1af4

Browse files
committed
add optional_import to downsample block test
Signed-off-by: tisalon <[email protected]>
1 parent 6352ba9 commit c7b1af4

File tree

1 file changed

+7
-4
lines changed

1 file changed

+7
-4
lines changed

tests/test_downsample_block.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,9 @@
1818

1919
from monai.networks import eval_mode
2020
from monai.networks.blocks import DownSample, MaxAvgPool, SubpixelDownsample, SubpixelUpsample
21+
from monai.utils import optional_import
22+
23+
einops, has_einops = optional_import("einops")
2124

2225
TEST_CASES = [
2326
[{"spatial_dims": 2, "kernel_size": 2}, (7, 4, 64, 48), (7, 8, 32, 24)], # 4-channel 2D, batch 7
@@ -82,7 +85,7 @@ def test_predefined_tensor(self):
8285
self.assertTrue(torch.all(result[0, 8:11] == 2))
8386
self.assertTrue(torch.all(result[0, 12:15] == 3))
8487

85-
def test_reconstruction_2D(self):
88+
def test_reconstruction_2d(self):
8689
input_tensor = torch.randn(1, 1, 4, 4)
8790
down = SubpixelDownsample(spatial_dims=2, in_channels=1, scale_factor=2, conv_block=None)
8891
up = SubpixelUpsample(spatial_dims=2, in_channels=4, scale_factor=2, conv_block=None, apply_pad_pool=False)
@@ -91,7 +94,7 @@ def test_reconstruction_2D(self):
9194
reconstructed = up(downsampled)
9295
self.assertTrue(torch.allclose(input_tensor, reconstructed, rtol=1e-5))
9396

94-
def test_reconstruction_3D(self):
97+
def test_reconstruction_3d(self):
9598
input_tensor = torch.randn(1, 1, 4, 4, 4)
9699
down = SubpixelDownsample(spatial_dims=3, in_channels=1, scale_factor=2, conv_block=None)
97100
up = SubpixelUpsample(spatial_dims=3, in_channels=4, scale_factor=2, conv_block=None, apply_pad_pool=False)
@@ -135,7 +138,7 @@ def test_pre_post_conv(self):
135138
self.assertEqual(result.shape, (1, 16, 8, 8))
136139

137140
def test_pixelunshuffle_equivalence(self):
138-
class DownSample_local(torch.nn.Module):
141+
class DownSampleLocal(torch.nn.Module):
139142
def __init__(self, n_feat: int):
140143
super().__init__()
141144
self.conv = torch.nn.Conv2d(n_feat, n_feat // 2, kernel_size=3, stride=1, padding=1, bias=False)
@@ -158,7 +161,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
158161
pre_conv=fix_weight_conv,
159162
)
160163

161-
local_down = DownSample_local(n_feat)
164+
local_down = DownSampleLocal(n_feat)
162165
local_down.conv.weight.data = fix_weight_conv.weight.data.clone()
163166

164167
with eval_mode(monai_down), eval_mode(local_down):

0 commit comments

Comments
 (0)