Skip to content

Commit 64b203d

Browse files
committed
Update args naming in unit restormer test for consistency with suggested changes
Signed-off-by: tisalon <[email protected]>
1 parent 78ce56b commit 64b203d

File tree

1 file changed

+18
-15
lines changed

1 file changed

+18
-15
lines changed

tests/test_restormer.py

Lines changed: 18 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,10 @@
1010
# limitations under the License.
1111

1212
from __future__ import annotations
13+
import os
14+
import sys
1315

16+
sys.path.insert(0, os.path.join(os.path.dirname(os.path.realpath(__file__)), "../"))
1417
import unittest
1518

1619
import torch
@@ -20,15 +23,15 @@
2023
from monai.networks.nets.restormer import MDTATransformerBlock, OverlapPatchEmbed, Restormer
2124

2225
TEST_CASES_TRANSFORMER = [
23-
# [spatial_dims, dim, num_heads, ffn_factor, bias, norm_type, flash_attn, input_shape]
24-
[2, 48, 8, 2.66, True, "WithBias", False, (2, 48, 64, 64)],
25-
[2, 96, 8, 2.66, False, "BiasFree", False, (2, 96, 32, 32)],
26-
[3, 48, 4, 2.66, True, "WithBias", False, (2, 48, 32, 32, 32)],
27-
[3, 96, 8, 2.66, False, "BiasFree", True, (2, 96, 16, 16, 16)],
26+
# [spatial_dims, dim, num_heads, ffn_factor, bias, layer_norm_use_bias, flash_attn, input_shape]
27+
[2, 48, 8, 2.66, True, True, False, (2, 48, 64, 64)],
28+
[2, 96, 8, 2.66, False, False, False, (2, 96, 32, 32)],
29+
[3, 48, 4, 2.66, True, True, False, (2, 48, 32, 32, 32)],
30+
[3, 96, 8, 2.66, False, False, True, (2, 96, 16, 16, 16)],
2831
]
2932

3033
TEST_CASES_PATCHEMBED = [
31-
# spatial_dims, in_c, embed_dim, input_shape, expected_shape
34+
# spatial_dims, in_channels, embed_dim, input_shape, expected_shape
3235
[2, 1, 48, (2, 1, 64, 64), (2, 48, 64, 64)],
3336
[2, 3, 96, (2, 3, 32, 32), (2, 96, 32, 32)],
3437
[3, 1, 48, (2, 1, 32, 32, 32), (2, 48, 32, 32, 32)],
@@ -52,7 +55,7 @@
5255
[
5356
{
5457
"spatial_dims": 2,
55-
"inp_channels": 1,
58+
"in_channels": 1,
5659
"out_channels": 1,
5760
"dim": 48,
5861
"num_blocks": config["num_blocks"],
@@ -67,9 +70,9 @@
6770
[
6871
{
6972
"spatial_dims": 3,
70-
"inp_channels": 1,
73+
"in_channels": 1,
7174
"out_channels": 1,
72-
"dim": 48,
75+
"dim": 16,
7376
"num_blocks": config["num_blocks"],
7477
"heads": config["heads"],
7578
"num_refinement_blocks": 2,
@@ -85,14 +88,14 @@
8588
class TestMDTATransformerBlock(unittest.TestCase):
8689

8790
@parameterized.expand(TEST_CASES_TRANSFORMER)
88-
def test_shape(self, spatial_dims, dim, heads, ffn_factor, bias, norm_type, flash, shape):
91+
def test_shape(self, spatial_dims, dim, heads, ffn_factor, bias, layer_norm_use_bias, flash, shape):
8992
block = MDTATransformerBlock(
9093
spatial_dims=spatial_dims,
9194
dim=dim,
9295
num_heads=heads,
9396
ffn_expansion_factor=ffn_factor,
9497
bias=bias,
95-
LayerNorm_type=norm_type,
98+
layer_norm_use_bias=layer_norm_use_bias,
9699
flash_attention=flash,
97100
)
98101
with eval_mode(block):
@@ -104,8 +107,8 @@ def test_shape(self, spatial_dims, dim, heads, ffn_factor, bias, norm_type, flas
104107
class TestOverlapPatchEmbed(unittest.TestCase):
105108

106109
@parameterized.expand(TEST_CASES_PATCHEMBED)
107-
def test_shape(self, spatial_dims, in_c, embed_dim, input_shape, expected_shape):
108-
net = OverlapPatchEmbed(spatial_dims=spatial_dims, in_c=in_c, embed_dim=embed_dim)
110+
def test_shape(self, spatial_dims, in_channels, embed_dim, input_shape, expected_shape):
111+
net = OverlapPatchEmbed(spatial_dims=spatial_dims, in_channels=in_channels, embed_dim=embed_dim)
109112
with eval_mode(net):
110113
result = net(torch.randn(input_shape))
111114
self.assertEqual(result.shape, expected_shape)
@@ -121,12 +124,12 @@ def test_shape(self, input_param, input_shape, expected_shape):
121124
self.assertEqual(result.shape, expected_shape)
122125

123126
def test_small_input_error_2d(self):
124-
net = Restormer(spatial_dims=2, inp_channels=1, out_channels=1)
127+
net = Restormer(spatial_dims=2, in_channels=1, out_channels=1)
125128
with self.assertRaises(AssertionError):
126129
net(torch.randn(1, 1, 8, 8))
127130

128131
def test_small_input_error_3d(self):
129-
net = Restormer(spatial_dims=3, inp_channels=1, out_channels=1)
132+
net = Restormer(spatial_dims=3, in_channels=1, out_channels=1)
130133
with self.assertRaises(AssertionError):
131134
net(torch.randn(1, 1, 8, 8, 8))
132135

0 commit comments

Comments
 (0)