Skip to content

Commit 84c0f48

Browse files
author
Fabio Ferreira
committed
test: add checkpoint unet test
1 parent 43dec88 commit 84c0f48

File tree

2 files changed

+219
-1
lines changed

2 files changed

+219
-1
lines changed

monai/networks/nets/unet.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -303,11 +303,21 @@ class CheckpointUNet(UNet):
303303
"""UNet variant that wraps internal connection blocks with activation checkpointing.
304304
305305
See `UNet` for constructor arguments. During training with gradients enabled,
306-
intermediate activations inside encoderdecoder connections are recomputed in
306+
intermediate activations inside encoder-decoder connections are recomputed in
307307
the backward pass to reduce peak memory usage at the cost of extra compute.
308308
"""
309309

310310
def _get_connection_block(self, down_path: nn.Module, up_path: nn.Module, subblock: nn.Module) -> nn.Module:
311+
"""Returns connection block with activation checkpointing applied to all components.
312+
313+
Args:
314+
down_path: encoding half of the layer (will be wrapped with checkpointing).
315+
up_path: decoding half of the layer (will be wrapped with checkpointing).
316+
subblock: block defining the next layer (will be wrapped with checkpointing).
317+
318+
Returns:
319+
Connection block with all components wrapped for activation checkpointing.
320+
"""
311321
subblock = ActivationCheckpointWrapper(subblock)
312322
down_path = ActivationCheckpointWrapper(down_path)
313323
up_path = ActivationCheckpointWrapper(up_path)
Lines changed: 208 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,208 @@
1+
# Copyright (c) MONAI Consortium
2+
# Licensed under the Apache License, Version 2.0 (the "License");
3+
# you may not use this file except in compliance with the License.
4+
# You may obtain a copy of the License at
5+
# http://www.apache.org/licenses/LICENSE-2.0
6+
# Unless required by applicable law or agreed to in writing, software
7+
# distributed under the License is distributed on an "AS IS" BASIS,
8+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9+
# See the License for the specific language governing permissions and
10+
# limitations under the License.
11+
12+
from __future__ import annotations
13+
14+
import unittest
15+
16+
import torch
17+
from parameterized import parameterized
18+
19+
from monai.networks import eval_mode
20+
from monai.networks.layers import Act, Norm
21+
from monai.networks.nets.unet import CheckpointUNet
22+
from tests.test_utils import test_script_save
23+
24+
device = "cuda" if torch.cuda.is_available() else "cpu"
25+
26+
TEST_CASE_0 = [ # single channel 2D, batch 16, no residual
27+
{
28+
"spatial_dims": 2,
29+
"in_channels": 1,
30+
"out_channels": 3,
31+
"channels": (16, 32, 64),
32+
"strides": (2, 2),
33+
"num_res_units": 0,
34+
},
35+
(16, 1, 32, 32),
36+
(16, 3, 32, 32),
37+
]
38+
39+
TEST_CASE_1 = [ # single channel 2D, batch 16
40+
{
41+
"spatial_dims": 2,
42+
"in_channels": 1,
43+
"out_channels": 3,
44+
"channels": (16, 32, 64),
45+
"strides": (2, 2),
46+
"num_res_units": 1,
47+
},
48+
(16, 1, 32, 32),
49+
(16, 3, 32, 32),
50+
]
51+
52+
TEST_CASE_2 = [ # single channel 3D, batch 16
53+
{
54+
"spatial_dims": 3,
55+
"in_channels": 1,
56+
"out_channels": 3,
57+
"channels": (16, 32, 64),
58+
"strides": (2, 2),
59+
"num_res_units": 1,
60+
},
61+
(16, 1, 32, 24, 48),
62+
(16, 3, 32, 24, 48),
63+
]
64+
65+
TEST_CASE_3 = [ # 4-channel 3D, batch 16
66+
{
67+
"spatial_dims": 3,
68+
"in_channels": 4,
69+
"out_channels": 3,
70+
"channels": (16, 32, 64),
71+
"strides": (2, 2),
72+
"num_res_units": 1,
73+
},
74+
(16, 4, 32, 64, 48),
75+
(16, 3, 32, 64, 48),
76+
]
77+
78+
TEST_CASE_4 = [ # 4-channel 3D, batch 16, batch normalization
79+
{
80+
"spatial_dims": 3,
81+
"in_channels": 4,
82+
"out_channels": 3,
83+
"channels": (16, 32, 64),
84+
"strides": (2, 2),
85+
"num_res_units": 1,
86+
"norm": Norm.BATCH,
87+
},
88+
(16, 4, 32, 64, 48),
89+
(16, 3, 32, 64, 48),
90+
]
91+
92+
TEST_CASE_5 = [ # 4-channel 3D, batch 16, LeakyReLU activation
93+
{
94+
"spatial_dims": 3,
95+
"in_channels": 4,
96+
"out_channels": 3,
97+
"channels": (16, 32, 64),
98+
"strides": (2, 2),
99+
"num_res_units": 1,
100+
"act": (Act.LEAKYRELU, {"negative_slope": 0.2}),
101+
"adn_ordering": "NA",
102+
},
103+
(16, 4, 32, 64, 48),
104+
(16, 3, 32, 64, 48),
105+
]
106+
107+
TEST_CASE_6 = [ # 4-channel 3D, batch 16, LeakyReLU activation explicit
108+
{
109+
"spatial_dims": 3,
110+
"in_channels": 4,
111+
"out_channels": 3,
112+
"channels": (16, 32, 64),
113+
"strides": (2, 2),
114+
"num_res_units": 1,
115+
"act": (torch.nn.LeakyReLU, {"negative_slope": 0.2}),
116+
},
117+
(16, 4, 32, 64, 48),
118+
(16, 3, 32, 64, 48),
119+
]
120+
121+
CASES = [TEST_CASE_0, TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5, TEST_CASE_6]
122+
123+
ILL_CASES = [
124+
[
125+
{ # len(channels) < 2
126+
"spatial_dims": 2,
127+
"in_channels": 1,
128+
"out_channels": 3,
129+
"channels": (16,),
130+
"strides": (2, 2),
131+
"num_res_units": 0,
132+
}
133+
],
134+
[
135+
{ # len(strides) < len(channels) - 1
136+
"spatial_dims": 2,
137+
"in_channels": 1,
138+
"out_channels": 3,
139+
"channels": (8, 8, 8),
140+
"strides": (2,),
141+
"num_res_units": 0,
142+
}
143+
],
144+
[
145+
{ # len(kernel_size) = 3, spatial_dims = 2
146+
"spatial_dims": 2,
147+
"in_channels": 1,
148+
"out_channels": 3,
149+
"channels": (8, 8, 8),
150+
"strides": (2, 2),
151+
"kernel_size": (3, 3, 3),
152+
}
153+
],
154+
[
155+
{ # len(up_kernel_size) = 2, spatial_dims = 3
156+
"spatial_dims": 3,
157+
"in_channels": 1,
158+
"out_channels": 3,
159+
"channels": (8, 8, 8),
160+
"strides": (2, 2),
161+
"up_kernel_size": (3, 3),
162+
}
163+
],
164+
]
165+
166+
167+
class TestUNET(unittest.TestCase):
168+
@parameterized.expand(CASES)
169+
def test_shape(self, input_param, input_shape, expected_shape):
170+
net = CheckpointUNet(**input_param).to(device)
171+
with eval_mode(net):
172+
result = net.forward(torch.randn(input_shape).to(device))
173+
self.assertEqual(result.shape, expected_shape)
174+
175+
def test_script(self):
176+
net = CheckpointUNet(
177+
spatial_dims=2, in_channels=1, out_channels=3, channels=(16, 32, 64), strides=(2, 2), num_res_units=0
178+
)
179+
test_data = torch.randn(16, 1, 32, 32)
180+
test_script_save(net, test_data)
181+
182+
def test_script_without_running_stats(self):
183+
net = CheckpointUNet(
184+
spatial_dims=2,
185+
in_channels=1,
186+
out_channels=3,
187+
channels=(16, 32, 64),
188+
strides=(2, 2),
189+
num_res_units=0,
190+
norm=("batch", {"track_running_stats": False}),
191+
)
192+
test_data = torch.randn(16, 1, 16, 4)
193+
test_script_save(net, test_data)
194+
195+
def test_ill_input_shape(self):
196+
net = CheckpointUNet(spatial_dims=2, in_channels=1, out_channels=3, channels=(16, 32, 64), strides=(2, 2))
197+
with eval_mode(net):
198+
with self.assertRaisesRegex(RuntimeError, "Sizes of tensors must match"):
199+
net.forward(torch.randn(2, 1, 16, 5))
200+
201+
@parameterized.expand(ILL_CASES)
202+
def test_ill_input_hyper_params(self, input_param):
203+
with self.assertRaises(ValueError):
204+
_ = CheckpointUNet(**input_param)
205+
206+
207+
if __name__ == "__main__":
208+
unittest.main()

0 commit comments

Comments
 (0)