@@ -10,7 +10,10 @@
 # limitations under the License.
 
 from __future__ import annotations
+import os
+import sys
 
+sys.path.insert(0, os.path.join(os.path.dirname(os.path.realpath(__file__)), "../"))
 import unittest
 
 import torch
@@ -20,15 +23,15 @@
 from monai.networks.nets.restormer import MDTATransformerBlock, OverlapPatchEmbed, Restormer
 
 TEST_CASES_TRANSFORMER = [
-    # [spatial_dims, dim, num_heads, ffn_factor, bias, norm_type, flash_attn, input_shape]
-    [2, 48, 8, 2.66, True, "WithBias", False, (2, 48, 64, 64)],
-    [2, 96, 8, 2.66, False, "BiasFree", False, (2, 96, 32, 32)],
-    [3, 48, 4, 2.66, True, "WithBias", False, (2, 48, 32, 32, 32)],
-    [3, 96, 8, 2.66, False, "BiasFree", True, (2, 96, 16, 16, 16)],
+    # [spatial_dims, dim, num_heads, ffn_factor, bias, layer_norm_use_bias, flash_attn, input_shape]
+    [2, 48, 8, 2.66, True, True, False, (2, 48, 64, 64)],
+    [2, 96, 8, 2.66, False, False, False, (2, 96, 32, 32)],
+    [3, 48, 4, 2.66, True, True, False, (2, 48, 32, 32, 32)],
+    [3, 96, 8, 2.66, False, False, True, (2, 96, 16, 16, 16)],
 ]
 
 TEST_CASES_PATCHEMBED = [
-    # spatial_dims, in_c, embed_dim, input_shape, expected_shape
+    # spatial_dims, in_channels, embed_dim, input_shape, expected_shape
     [2, 1, 48, (2, 1, 64, 64), (2, 48, 64, 64)],
     [2, 3, 96, (2, 3, 32, 32), (2, 96, 32, 32)],
     [3, 1, 48, (2, 1, 32, 32, 32), (2, 48, 32, 32, 32)],
@@ -52,7 +55,7 @@
         [
             {
                 "spatial_dims": 2,
-                "inp_channels": 1,
+                "in_channels": 1,
                 "out_channels": 1,
                 "dim": 48,
                 "num_blocks": config["num_blocks"],
@@ -67,9 +70,9 @@
         [
             {
                 "spatial_dims": 3,
-                "inp_channels": 1,
+                "in_channels": 1,
                 "out_channels": 1,
-                "dim": 48,
+                "dim": 16,
                 "num_blocks": config["num_blocks"],
                 "heads": config["heads"],
                 "num_refinement_blocks": 2,
@@ -85,14 +88,14 @@
 class TestMDTATransformerBlock(unittest.TestCase):
 
     @parameterized.expand(TEST_CASES_TRANSFORMER)
-    def test_shape(self, spatial_dims, dim, heads, ffn_factor, bias, norm_type, flash, shape):
+    def test_shape(self, spatial_dims, dim, heads, ffn_factor, bias, layer_norm_use_bias, flash, shape):
         block = MDTATransformerBlock(
             spatial_dims=spatial_dims,
             dim=dim,
             num_heads=heads,
             ffn_expansion_factor=ffn_factor,
             bias=bias,
-            LayerNorm_type=norm_type,
+            layer_norm_use_bias=layer_norm_use_bias,
             flash_attention=flash,
         )
         with eval_mode(block):
@@ -104,8 +107,8 @@ def test_shape(self, spatial_dims, dim, heads, ffn_factor, bias, norm_type, flas
 class TestOverlapPatchEmbed(unittest.TestCase):
 
     @parameterized.expand(TEST_CASES_PATCHEMBED)
-    def test_shape(self, spatial_dims, in_c, embed_dim, input_shape, expected_shape):
-        net = OverlapPatchEmbed(spatial_dims=spatial_dims, in_c=in_c, embed_dim=embed_dim)
+    def test_shape(self, spatial_dims, in_channels, embed_dim, input_shape, expected_shape):
+        net = OverlapPatchEmbed(spatial_dims=spatial_dims, in_channels=in_channels, embed_dim=embed_dim)
         with eval_mode(net):
             result = net(torch.randn(input_shape))
             self.assertEqual(result.shape, expected_shape)
@@ -121,12 +124,12 @@ def test_shape(self, input_param, input_shape, expected_shape):
             self.assertEqual(result.shape, expected_shape)
 
     def test_small_input_error_2d(self):
-        net = Restormer(spatial_dims=2, inp_channels=1, out_channels=1)
+        net = Restormer(spatial_dims=2, in_channels=1, out_channels=1)
         with self.assertRaises(AssertionError):
             net(torch.randn(1, 1, 8, 8))
 
     def test_small_input_error_3d(self):
-        net = Restormer(spatial_dims=3, inp_channels=1, out_channels=1)
+        net = Restormer(spatial_dims=3, in_channels=1, out_channels=1)
         with self.assertRaises(AssertionError):
             net(torch.randn(1, 1, 8, 8, 8))
 
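# A minimal usage sketch, assuming the renamed keyword arguments exercised by the tests
# above (in_channels/out_channels on Restormer, layer_norm_use_bias on MDTATransformerBlock).
# The 64x64 input size and the expected output shapes are assumptions chosen to stay above
# the small-input assertion tested above; treat this as an illustration, not canonical docs.
import torch

from monai.networks.nets.restormer import MDTATransformerBlock, Restormer

net = Restormer(spatial_dims=2, in_channels=1, out_channels=1)
block = MDTATransformerBlock(
    spatial_dims=2,
    dim=48,
    num_heads=8,
    ffn_expansion_factor=2.66,
    bias=True,
    layer_norm_use_bias=True,
    flash_attention=False,
)
net.eval()
block.eval()
with torch.no_grad():
    img = torch.randn(1, 1, 64, 64)    # N, C, H, W; large enough to avoid the size assertion
    print(net(img).shape)              # expected to match the input spatial size (assumption)
    feat = torch.randn(1, 48, 64, 64)  # channel dimension must equal `dim`
    print(block(feat).shape)           # block output expected to keep the input shape (assumption)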