Skip to content

Commit 1aa8e3c

Browse files
author
Fabio Ferreira
committed
fix: simplify test and make sure that checkpoint unet runs well in training
1 parent 5805515 commit 1aa8e3c

File tree

1 file changed

+46
-111
lines changed

1 file changed

+46
-111
lines changed

tests/networks/nets/test_checkpointunet.py

Lines changed: 46 additions & 111 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
# Copyright (c) MONAI Consortium
22
# Licensed under the Apache License, Version 2.0 (the "License");
3-
# you may not use this file except in compliance with the License.
3+
# You may not use this file except in compliance with the License.
44
# You may obtain a copy of the License at
55
# http://www.apache.org/licenses/LICENSE-2.0
66
# Unless required by applicable law or agreed to in writing, software
@@ -17,13 +17,12 @@
1717
from parameterized import parameterized
1818

1919
from monai.networks import eval_mode
20-
from monai.networks.layers import Act, Norm
21-
from monai.networks.nets.unet import CheckpointUNet
20+
from monai.networks.nets.unet import CheckpointUNet, UNet
2221
from tests.test_utils import test_script_save
2322

2423
device = "cuda" if torch.cuda.is_available() else "cpu"
2524

26-
TEST_CASE_0 = [ # single channel 2D, batch 16, no residual
25+
TEST_CASE_0 = [
2726
{
2827
"spatial_dims": 2,
2928
"in_channels": 1,
@@ -36,7 +35,7 @@
3635
(16, 3, 32, 32),
3736
]
3837

39-
TEST_CASE_1 = [ # single channel 2D, batch 16
38+
TEST_CASE_1 = [
4039
{
4140
"spatial_dims": 2,
4241
"in_channels": 1,
@@ -49,7 +48,7 @@
4948
(16, 3, 32, 32),
5049
]
5150

52-
TEST_CASE_2 = [ # single channel 3D, batch 16
51+
TEST_CASE_2 = [
5352
{
5453
"spatial_dims": 3,
5554
"in_channels": 1,
@@ -62,7 +61,7 @@
6261
(16, 3, 32, 24, 48),
6362
]
6463

65-
TEST_CASE_3 = [ # 4-channel 3D, batch 16
64+
TEST_CASE_3 = [
6665
{
6766
"spatial_dims": 3,
6867
"in_channels": 4,
@@ -75,93 +74,7 @@
7574
(16, 3, 32, 64, 48),
7675
]
7776

78-
TEST_CASE_4 = [ # 4-channel 3D, batch 16, batch normalization
79-
{
80-
"spatial_dims": 3,
81-
"in_channels": 4,
82-
"out_channels": 3,
83-
"channels": (16, 32, 64),
84-
"strides": (2, 2),
85-
"num_res_units": 1,
86-
"norm": Norm.BATCH,
87-
},
88-
(16, 4, 32, 64, 48),
89-
(16, 3, 32, 64, 48),
90-
]
91-
92-
TEST_CASE_5 = [ # 4-channel 3D, batch 16, LeakyReLU activation
93-
{
94-
"spatial_dims": 3,
95-
"in_channels": 4,
96-
"out_channels": 3,
97-
"channels": (16, 32, 64),
98-
"strides": (2, 2),
99-
"num_res_units": 1,
100-
"act": (Act.LEAKYRELU, {"negative_slope": 0.2}),
101-
"adn_ordering": "NA",
102-
},
103-
(16, 4, 32, 64, 48),
104-
(16, 3, 32, 64, 48),
105-
]
106-
107-
TEST_CASE_6 = [ # 4-channel 3D, batch 16, LeakyReLU activation explicit
108-
{
109-
"spatial_dims": 3,
110-
"in_channels": 4,
111-
"out_channels": 3,
112-
"channels": (16, 32, 64),
113-
"strides": (2, 2),
114-
"num_res_units": 1,
115-
"act": (torch.nn.LeakyReLU, {"negative_slope": 0.2}),
116-
},
117-
(16, 4, 32, 64, 48),
118-
(16, 3, 32, 64, 48),
119-
]
120-
121-
CASES = [TEST_CASE_0, TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5, TEST_CASE_6]
122-
123-
ILL_CASES = [
124-
[
125-
{ # len(channels) < 2
126-
"spatial_dims": 2,
127-
"in_channels": 1,
128-
"out_channels": 3,
129-
"channels": (16,),
130-
"strides": (2, 2),
131-
"num_res_units": 0,
132-
}
133-
],
134-
[
135-
{ # len(strides) < len(channels) - 1
136-
"spatial_dims": 2,
137-
"in_channels": 1,
138-
"out_channels": 3,
139-
"channels": (8, 8, 8),
140-
"strides": (2,),
141-
"num_res_units": 0,
142-
}
143-
],
144-
[
145-
{ # len(kernel_size) = 3, spatial_dims = 2
146-
"spatial_dims": 2,
147-
"in_channels": 1,
148-
"out_channels": 3,
149-
"channels": (8, 8, 8),
150-
"strides": (2, 2),
151-
"kernel_size": (3, 3, 3),
152-
}
153-
],
154-
[
155-
{ # len(up_kernel_size) = 2, spatial_dims = 3
156-
"spatial_dims": 3,
157-
"in_channels": 1,
158-
"out_channels": 3,
159-
"channels": (8, 8, 8),
160-
"strides": (2, 2),
161-
"up_kernel_size": (3, 3),
162-
}
163-
],
164-
]
77+
CASES = [TEST_CASE_0, TEST_CASE_1, TEST_CASE_2, TEST_CASE_3]
16578

16679

16780
class TestCheckpointUNet(unittest.TestCase):
@@ -179,29 +92,51 @@ def test_script(self):
17992
test_data = torch.randn(16, 1, 32, 32)
18093
test_script_save(net, test_data)
18194

182-
def test_script_without_running_stats(self):
183-
net = CheckpointUNet(
184-
spatial_dims=2,
185-
in_channels=1,
186-
out_channels=3,
187-
channels=(16, 32, 64),
188-
strides=(2, 2),
189-
num_res_units=0,
190-
norm=("batch", {"track_running_stats": False}),
191-
)
192-
test_data = torch.randn(16, 1, 16, 4)
193-
test_script_save(net, test_data)
194-
19595
def test_ill_input_shape(self):
19696
net = CheckpointUNet(spatial_dims=2, in_channels=1, out_channels=3, channels=(16, 32, 64), strides=(2, 2))
19797
with eval_mode(net):
19898
with self.assertRaisesRegex(RuntimeError, "Sizes of tensors must match"):
19999
net.forward(torch.randn(2, 1, 16, 5))
200100

201-
@parameterized.expand(ILL_CASES)
202-
def test_ill_input_hyper_params(self, input_param):
203-
with self.assertRaises(ValueError):
204-
_ = CheckpointUNet(**input_param)
101+
def test_checkpointing_equivalence_eval(self):
102+
"""Ensure that CheckpointUNet matches standard UNet in eval mode (checkpointing inactive)."""
103+
params = dict(
104+
spatial_dims=2, in_channels=1, out_channels=2, channels=(8, 16, 32), strides=(2, 2), num_res_units=1
105+
)
106+
107+
x = torch.randn(2, 1, 32, 32, device=device)
108+
109+
net_ckpt = CheckpointUNet(**params).to(device)
110+
net_plain = UNet(**params).to(device)
111+
112+
with eval_mode(net_ckpt), eval_mode(net_plain):
113+
y_ckpt = net_ckpt(x)
114+
y_plain = net_plain(x)
115+
116+
# checkpointing should not change outputs in eval mode
117+
self.assertTrue(torch.allclose(y_ckpt, y_plain, atol=1e-6, rtol=1e-5))
118+
119+
def test_checkpointing_activates_training(self):
120+
"""Ensure checkpointing triggers recomputation under training and gradients propagate."""
121+
params = dict(
122+
spatial_dims=2, in_channels=1, out_channels=1, channels=(8, 16, 32), strides=(2, 2), num_res_units=1
123+
)
124+
125+
net = CheckpointUNet(**params).to(device)
126+
net.train()
127+
128+
x = torch.randn(2, 1, 32, 32, device=device, requires_grad=True)
129+
y = net(x)
130+
loss = y.mean()
131+
loss.backward()
132+
133+
# gradient flow check
134+
grad_norm = sum(p.grad.abs().sum() for p in net.parameters() if p.grad is not None)
135+
self.assertGreater(grad_norm.item(), 0.0)
136+
137+
# checkpointing should reduce activation memory use; we can't directly assert memory savings
138+
# but we can confirm no runtime errors and gradients propagate correctly
139+
self.assertIsNotNone(grad_norm)
205140

206141

207142
if __name__ == "__main__":

0 commit comments

Comments (0)