
Commit 55da640

relocate tests to the correct place
1 parent 0b0e4df commit 55da640

File tree: 4 files changed, +398 −0 lines changed

Lines changed: 150 additions & 0 deletions
@@ -0,0 +1,150 @@
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import unittest
from unittest import skipUnless

import torch
from parameterized import parameterized

from monai.networks import eval_mode
from monai.networks.blocks.cablock import CABlock, FeedForward
from monai.utils import optional_import
from tests.test_utils import SkipIfBeforePyTorchVersion, assert_allclose

einops, has_einops = optional_import("einops")


TEST_CASES_CAB = []
for spatial_dims in [2, 3]:
    for dim in [32, 64, 128]:
        for num_heads in [2, 4, 8]:
            for bias in [True, False]:
                test_case = [
                    {
                        "spatial_dims": spatial_dims,
                        "dim": dim,
                        "num_heads": num_heads,
                        "bias": bias,
                        "flash_attention": False,
                    },
                    (2, dim, *([16] * spatial_dims)),
                    (2, dim, *([16] * spatial_dims)),
                ]
                TEST_CASES_CAB.append(test_case)


TEST_CASES_FEEDFORWARD = [
    # Test different spatial dims, dimensions and expansion factors
    [{"spatial_dims": 2, "dim": 64, "ffn_expansion_factor": 2.0, "bias": True}, (2, 64, 32, 32)],
    [{"spatial_dims": 3, "dim": 128, "ffn_expansion_factor": 1.5, "bias": False}, (2, 128, 16, 16, 16)],
    [{"spatial_dims": 2, "dim": 256, "ffn_expansion_factor": 1.0, "bias": True}, (1, 256, 64, 64)],
]


class TestFeedForward(unittest.TestCase):

    @parameterized.expand(TEST_CASES_FEEDFORWARD)
    def test_shape(self, input_param, input_shape):
        net = FeedForward(**input_param)
        with eval_mode(net):
            result = net(torch.randn(input_shape))
            self.assertEqual(result.shape, input_shape)

    def test_gating_mechanism(self):
        net = FeedForward(spatial_dims=2, dim=32, ffn_expansion_factor=2.0, bias=True)
        x = torch.ones(1, 32, 16, 16)
        out = net(x)
        self.assertNotEqual(torch.sum(out), torch.sum(x))


class TestCABlock(unittest.TestCase):

    @parameterized.expand(TEST_CASES_CAB)
    @skipUnless(has_einops, "Requires einops")
    def test_shape(self, input_param, input_shape, expected_shape):
        net = CABlock(**input_param)
        with eval_mode(net):
            result = net(torch.randn(input_shape))
            self.assertEqual(result.shape, expected_shape)

    @skipUnless(has_einops, "Requires einops")
    def test_invalid_spatial_dims(self):
        with self.assertRaises(ValueError):
            CABlock(spatial_dims=4, dim=64, num_heads=4, bias=True)

    @SkipIfBeforePyTorchVersion((2, 0))
    @skipUnless(has_einops, "Requires einops")
    def test_flash_attention(self):
        device = "cuda" if torch.cuda.is_available() else "cpu"
        block = CABlock(spatial_dims=2, dim=64, num_heads=4, bias=True, flash_attention=True).to(device)
        x = torch.randn(2, 64, 32, 32).to(device)
        output = block(x)
        self.assertEqual(output.shape, x.shape)

    @skipUnless(has_einops, "Requires einops")
    def test_temperature_parameter(self):
        block = CABlock(spatial_dims=2, dim=64, num_heads=4, bias=True)
        self.assertTrue(isinstance(block.temperature, torch.nn.Parameter))
        self.assertEqual(block.temperature.shape, (4, 1, 1))

    @skipUnless(has_einops, "Requires einops")
    def test_qkv_transformation_2d(self):
        block = CABlock(spatial_dims=2, dim=64, num_heads=4, bias=True)
        x = torch.randn(2, 64, 32, 32)
        qkv = block.qkv(x)
        self.assertEqual(qkv.shape, (2, 192, 32, 32))

    @skipUnless(has_einops, "Requires einops")
    def test_qkv_transformation_3d(self):
        block = CABlock(spatial_dims=3, dim=64, num_heads=4, bias=True)
        x = torch.randn(2, 64, 16, 16, 16)
        qkv = block.qkv(x)
        self.assertEqual(qkv.shape, (2, 192, 16, 16, 16))

    @SkipIfBeforePyTorchVersion((2, 0))
    @skipUnless(has_einops, "Requires einops")
    def test_flash_vs_normal_attention(self):
        device = "cuda" if torch.cuda.is_available() else "cpu"
        block_flash = CABlock(spatial_dims=2, dim=64, num_heads=4, bias=True, flash_attention=True).to(device)
        block_normal = CABlock(spatial_dims=2, dim=64, num_heads=4, bias=True, flash_attention=False).to(device)

        block_normal.load_state_dict(block_flash.state_dict())

        x = torch.randn(2, 64, 32, 32).to(device)
        with torch.no_grad():
            out_flash = block_flash(x)
            out_normal = block_normal(x)

        assert_allclose(out_flash, out_normal, atol=1e-4)

    @skipUnless(has_einops, "Requires einops")
    def test_deterministic_small_input(self):
        block = CABlock(spatial_dims=2, dim=2, num_heads=1, bias=False)
        with torch.no_grad():
            block.qkv.conv.weight.data.fill_(1.0)
            block.qkv_dwconv.conv.weight.data.fill_(1.0)
            block.temperature.data.fill_(1.0)
            block.project_out.conv.weight.data.fill_(1.0)

        x = torch.tensor([[[[1.0, 2.0], [3.0, 4.0]], [[5.0, 6.0], [7.0, 8.0]]]], dtype=torch.float32)

        output = block(x)
        # Channel attention: sum([1..8]) * (qkv_conv=1) * (dwconv=1) * (attn_weights=1) * (proj=1) = 36 * 2 = 72
        expected = torch.full_like(x, 72.0)

        assert_allclose(output, expected, atol=1e-6)


if __name__ == "__main__":
    unittest.main()
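
For context, a minimal usage sketch (not part of the commit) of the two blocks these tests exercise. It assumes only the constructor signatures and shape behaviour shown in the tests above; CABlock additionally requires einops, per the skip conditions.

import torch

from monai.networks.blocks.cablock import CABlock, FeedForward

x = torch.randn(2, 64, 32, 32)  # (batch, channels, height, width)

# Channel attention block: the output shape matches the input shape (see test_shape above).
cab = CABlock(spatial_dims=2, dim=64, num_heads=4, bias=True, flash_attention=False)
print(cab(x).shape)  # torch.Size([2, 64, 32, 32])

# Gated feed-forward block: also shape-preserving.
ffn = FeedForward(spatial_dims=2, dim=64, ffn_expansion_factor=2.0, bias=True)
print(ffn(x).shape)  # torch.Size([2, 64, 32, 32])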
Lines changed: 50 additions & 0 deletions
@@ -0,0 +1,50 @@
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import unittest

import torch
from parameterized import parameterized

from monai.networks import eval_mode
from monai.networks.blocks import MaxAvgPool

TEST_CASES = [
    [{"spatial_dims": 2, "kernel_size": 2}, (7, 4, 64, 48), (7, 8, 32, 24)],  # 4-channel 2D, batch 7
    [{"spatial_dims": 1, "kernel_size": 4}, (16, 4, 63), (16, 8, 15)],  # 4-channel 1D, batch 16
    [{"spatial_dims": 1, "kernel_size": 4, "padding": 1}, (16, 4, 63), (16, 8, 16)],  # 4-channel 1D, batch 16
    [  # 4-channel 3D, batch 16
        {"spatial_dims": 3, "kernel_size": 3, "ceil_mode": True},
        (16, 4, 32, 24, 48),
        (16, 8, 11, 8, 16),
    ],
    [  # 1-channel 3D, batch 16
        {"spatial_dims": 3, "kernel_size": 3, "ceil_mode": False},
        (16, 1, 32, 24, 48),
        (16, 2, 10, 8, 16),
    ],
]


class TestMaxAvgPool(unittest.TestCase):

    @parameterized.expand(TEST_CASES)
    def test_shape(self, input_param, input_shape, expected_shape):
        net = MaxAvgPool(**input_param)
        with eval_mode(net):
            result = net(torch.randn(input_shape))
            self.assertEqual(result.shape, expected_shape)


if __name__ == "__main__":
    unittest.main()
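
As a quick illustration (not part of the commit) of the shape contract these cases encode: MaxAvgPool doubles the channel dimension, consistent with concatenating max-pooled and average-pooled outputs, while downsampling each spatial dimension by the kernel size.

import torch

from monai.networks.blocks import MaxAvgPool

pool = MaxAvgPool(spatial_dims=2, kernel_size=2)
x = torch.randn(7, 4, 64, 48)
print(pool(x).shape)  # torch.Size([7, 8, 32, 24]) -- channels 4 -> 8, spatial dims halved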
Lines changed: 147 additions & 0 deletions
@@ -0,0 +1,147 @@
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import unittest
from unittest import skipUnless

import torch
from parameterized import parameterized

from monai.networks import eval_mode
from monai.networks.nets.restormer import MDTATransformerBlock, OverlapPatchEmbed, Restormer
from monai.utils import optional_import

einops, has_einops = optional_import("einops")

TEST_CASES_TRANSFORMER = [
    # [spatial_dims, dim, num_heads, ffn_factor, bias, layer_norm_use_bias, flash_attn, input_shape]
    [2, 48, 8, 2.66, True, True, False, (2, 48, 64, 64)],
    [2, 96, 8, 2.66, False, False, False, (2, 96, 32, 32)],
    [3, 48, 4, 2.66, True, True, False, (2, 48, 32, 32, 32)],
    [3, 96, 8, 2.66, False, False, True, (2, 96, 16, 16, 16)],
]

TEST_CASES_PATCHEMBED = [
    # spatial_dims, in_channels, embed_dim, input_shape, expected_shape
    [2, 1, 48, (2, 1, 64, 64), (2, 48, 64, 64)],
    [2, 3, 96, (2, 3, 32, 32), (2, 96, 32, 32)],
    [3, 1, 48, (2, 1, 32, 32, 32), (2, 48, 32, 32, 32)],
    [3, 4, 64, (2, 4, 16, 16, 16), (2, 64, 16, 16, 16)],
]

RESTORMER_CONFIGS = [
    # 2-level architecture
    {"num_blocks": [1, 1], "heads": [1, 1]},
    {"num_blocks": [2, 1], "heads": [2, 1]},
    # 3-level architecture
    {"num_blocks": [1, 1, 1], "heads": [1, 1, 1]},
    {"num_blocks": [2, 1, 1], "heads": [2, 1, 1]},
]

TEST_CASES_RESTORMER = []
for config in RESTORMER_CONFIGS:
    # 2D cases
    TEST_CASES_RESTORMER.extend(
        [
            [
                {
                    "spatial_dims": 2,
                    "in_channels": 1,
                    "out_channels": 1,
                    "dim": 48,
                    "num_blocks": config["num_blocks"],
                    "heads": config["heads"],
                    "num_refinement_blocks": 2,
                    "ffn_expansion_factor": 1.5,
                },
                (2, 1, 64, 64),
                (2, 1, 64, 64),
            ],
            # 3D cases
            [
                {
                    "spatial_dims": 3,
                    "in_channels": 1,
                    "out_channels": 1,
                    "dim": 16,
                    "num_blocks": config["num_blocks"],
                    "heads": config["heads"],
                    "num_refinement_blocks": 2,
                    "ffn_expansion_factor": 1.5,
                },
                (2, 1, 32, 32, 32),
                (2, 1, 32, 32, 32),
            ],
        ]
    )


class TestMDTATransformerBlock(unittest.TestCase):

    @skipUnless(has_einops, "Requires einops")
    @parameterized.expand(TEST_CASES_TRANSFORMER)
    def test_shape(self, spatial_dims, dim, heads, ffn_factor, bias, layer_norm_use_bias, flash, shape):
        if flash and not torch.cuda.is_available():
            self.skipTest("Flash attention requires CUDA")
        block = MDTATransformerBlock(
            spatial_dims=spatial_dims,
            dim=dim,
            num_heads=heads,
            ffn_expansion_factor=ffn_factor,
            bias=bias,
            layer_norm_use_bias=layer_norm_use_bias,
            flash_attention=flash,
        )
        with eval_mode(block):
            x = torch.randn(shape)
            output = block(x)
            self.assertEqual(output.shape, x.shape)


class TestOverlapPatchEmbed(unittest.TestCase):

    @parameterized.expand(TEST_CASES_PATCHEMBED)
    def test_shape(self, spatial_dims, in_channels, embed_dim, input_shape, expected_shape):
        net = OverlapPatchEmbed(spatial_dims=spatial_dims, in_channels=in_channels, embed_dim=embed_dim)
        with eval_mode(net):
            result = net(torch.randn(input_shape))
            self.assertEqual(result.shape, expected_shape)


class TestRestormer(unittest.TestCase):

    @skipUnless(has_einops, "Requires einops")
    @parameterized.expand(TEST_CASES_RESTORMER)
    def test_shape(self, input_param, input_shape, expected_shape):
        if input_param.get("flash_attention", False) and not torch.cuda.is_available():
            self.skipTest("Flash attention requires CUDA")
        net = Restormer(**input_param)
        with eval_mode(net):
            result = net(torch.randn(input_shape))
            self.assertEqual(result.shape, expected_shape)

    @skipUnless(has_einops, "Requires einops")
    def test_small_input_error_2d(self):
        net = Restormer(spatial_dims=2, in_channels=1, out_channels=1)
        with self.assertRaises(AssertionError):
            net(torch.randn(1, 1, 8, 8))

    @skipUnless(has_einops, "Requires einops")
    def test_small_input_error_3d(self):
        net = Restormer(spatial_dims=3, in_channels=1, out_channels=1)
        with self.assertRaises(AssertionError):
            net(torch.randn(1, 1, 8, 8, 8))


if __name__ == "__main__":
    unittest.main()
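
A minimal instantiation sketch (not part of the commit) of the network under test, using only parameters taken from the 2D test configuration above; einops is required, per the skip conditions.

import torch

from monai.networks.nets.restormer import Restormer

net = Restormer(
    spatial_dims=2,
    in_channels=1,
    out_channels=1,
    dim=48,
    num_blocks=[1, 1],
    heads=[1, 1],
    num_refinement_blocks=2,
    ffn_expansion_factor=1.5,
)
x = torch.randn(2, 1, 64, 64)
print(net(x).shape)  # torch.Size([2, 1, 64, 64]) -- restored output matches the input shape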
