11# Copyright (c) MONAI Consortium
22# Licensed under the Apache License, Version 2.0 (the "License");
3- # you may not use this file except in compliance with the License.
3+ # You may not use this file except in compliance with the License.
44# You may obtain a copy of the License at
55# http://www.apache.org/licenses/LICENSE-2.0
66# Unless required by applicable law or agreed to in writing, software
1717from parameterized import parameterized
1818
1919from monai .networks import eval_mode
20- from monai .networks .layers import Act , Norm
21- from monai .networks .nets .unet import CheckpointUNet
20+ from monai .networks .nets .unet import CheckpointUNet , UNet
2221from tests .test_utils import test_script_save
2322
2423device = "cuda" if torch .cuda .is_available () else "cpu"
2524
26- TEST_CASE_0 = [ # single channel 2D, batch 16, no residual
25+ TEST_CASE_0 = [
2726 {
2827 "spatial_dims" : 2 ,
2928 "in_channels" : 1 ,
3635 (16 , 3 , 32 , 32 ),
3736]
3837
39- TEST_CASE_1 = [ # single channel 2D, batch 16
38+ TEST_CASE_1 = [
4039 {
4140 "spatial_dims" : 2 ,
4241 "in_channels" : 1 ,
4948 (16 , 3 , 32 , 32 ),
5049]
5150
52- TEST_CASE_2 = [ # single channel 3D, batch 16
51+ TEST_CASE_2 = [
5352 {
5453 "spatial_dims" : 3 ,
5554 "in_channels" : 1 ,
6261 (16 , 3 , 32 , 24 , 48 ),
6362]
6463
65- TEST_CASE_3 = [ # 4-channel 3D, batch 16
64+ TEST_CASE_3 = [
6665 {
6766 "spatial_dims" : 3 ,
6867 "in_channels" : 4 ,
7574 (16 , 3 , 32 , 64 , 48 ),
7675]
7776
78- TEST_CASE_4 = [ # 4-channel 3D, batch 16, batch normalization
79- {
80- "spatial_dims" : 3 ,
81- "in_channels" : 4 ,
82- "out_channels" : 3 ,
83- "channels" : (16 , 32 , 64 ),
84- "strides" : (2 , 2 ),
85- "num_res_units" : 1 ,
86- "norm" : Norm .BATCH ,
87- },
88- (16 , 4 , 32 , 64 , 48 ),
89- (16 , 3 , 32 , 64 , 48 ),
90- ]
91-
92- TEST_CASE_5 = [ # 4-channel 3D, batch 16, LeakyReLU activation
93- {
94- "spatial_dims" : 3 ,
95- "in_channels" : 4 ,
96- "out_channels" : 3 ,
97- "channels" : (16 , 32 , 64 ),
98- "strides" : (2 , 2 ),
99- "num_res_units" : 1 ,
100- "act" : (Act .LEAKYRELU , {"negative_slope" : 0.2 }),
101- "adn_ordering" : "NA" ,
102- },
103- (16 , 4 , 32 , 64 , 48 ),
104- (16 , 3 , 32 , 64 , 48 ),
105- ]
106-
107- TEST_CASE_6 = [ # 4-channel 3D, batch 16, LeakyReLU activation explicit
108- {
109- "spatial_dims" : 3 ,
110- "in_channels" : 4 ,
111- "out_channels" : 3 ,
112- "channels" : (16 , 32 , 64 ),
113- "strides" : (2 , 2 ),
114- "num_res_units" : 1 ,
115- "act" : (torch .nn .LeakyReLU , {"negative_slope" : 0.2 }),
116- },
117- (16 , 4 , 32 , 64 , 48 ),
118- (16 , 3 , 32 , 64 , 48 ),
119- ]
120-
121- CASES = [TEST_CASE_0 , TEST_CASE_1 , TEST_CASE_2 , TEST_CASE_3 , TEST_CASE_4 , TEST_CASE_5 , TEST_CASE_6 ]
122-
123- ILL_CASES = [
124- [
125- { # len(channels) < 2
126- "spatial_dims" : 2 ,
127- "in_channels" : 1 ,
128- "out_channels" : 3 ,
129- "channels" : (16 ,),
130- "strides" : (2 , 2 ),
131- "num_res_units" : 0 ,
132- }
133- ],
134- [
135- { # len(strides) < len(channels) - 1
136- "spatial_dims" : 2 ,
137- "in_channels" : 1 ,
138- "out_channels" : 3 ,
139- "channels" : (8 , 8 , 8 ),
140- "strides" : (2 ,),
141- "num_res_units" : 0 ,
142- }
143- ],
144- [
145- { # len(kernel_size) = 3, spatial_dims = 2
146- "spatial_dims" : 2 ,
147- "in_channels" : 1 ,
148- "out_channels" : 3 ,
149- "channels" : (8 , 8 , 8 ),
150- "strides" : (2 , 2 ),
151- "kernel_size" : (3 , 3 , 3 ),
152- }
153- ],
154- [
155- { # len(up_kernel_size) = 2, spatial_dims = 3
156- "spatial_dims" : 3 ,
157- "in_channels" : 1 ,
158- "out_channels" : 3 ,
159- "channels" : (8 , 8 , 8 ),
160- "strides" : (2 , 2 ),
161- "up_kernel_size" : (3 , 3 ),
162- }
163- ],
164- ]
77+ CASES = [TEST_CASE_0 , TEST_CASE_1 , TEST_CASE_2 , TEST_CASE_3 ]
16578
16679
16780class TestCheckpointUNet (unittest .TestCase ):
@@ -179,29 +92,51 @@ def test_script(self):
17992 test_data = torch .randn (16 , 1 , 32 , 32 )
18093 test_script_save (net , test_data )
18194
182- def test_script_without_running_stats (self ):
183- net = CheckpointUNet (
184- spatial_dims = 2 ,
185- in_channels = 1 ,
186- out_channels = 3 ,
187- channels = (16 , 32 , 64 ),
188- strides = (2 , 2 ),
189- num_res_units = 0 ,
190- norm = ("batch" , {"track_running_stats" : False }),
191- )
192- test_data = torch .randn (16 , 1 , 16 , 4 )
193- test_script_save (net , test_data )
194-
19595 def test_ill_input_shape (self ):
19696 net = CheckpointUNet (spatial_dims = 2 , in_channels = 1 , out_channels = 3 , channels = (16 , 32 , 64 ), strides = (2 , 2 ))
19797 with eval_mode (net ):
19898 with self .assertRaisesRegex (RuntimeError , "Sizes of tensors must match" ):
19999 net .forward (torch .randn (2 , 1 , 16 , 5 ))
200100
201- @parameterized .expand (ILL_CASES )
202- def test_ill_input_hyper_params (self , input_param ):
203- with self .assertRaises (ValueError ):
204- _ = CheckpointUNet (** input_param )
101+ def test_checkpointing_equivalence_eval (self ):
102+ """Ensure that CheckpointUNet matches standard UNet in eval mode (checkpointing inactive)."""
103+ params = dict (
104+ spatial_dims = 2 , in_channels = 1 , out_channels = 2 , channels = (8 , 16 , 32 ), strides = (2 , 2 ), num_res_units = 1
105+ )
106+
107+ x = torch .randn (2 , 1 , 32 , 32 , device = device )
108+
109+ net_ckpt = CheckpointUNet (** params ).to (device )
110+ net_plain = UNet (** params ).to (device )
111+
112+ with eval_mode (net_ckpt ), eval_mode (net_plain ):
113+ y_ckpt = net_ckpt (x )
114+ y_plain = net_plain (x )
115+
116+ # checkpointing should not change outputs in eval mode
117+ self .assertTrue (torch .allclose (y_ckpt , y_plain , atol = 1e-6 , rtol = 1e-5 ))
118+
119+ def test_checkpointing_activates_training (self ):
120+ """Ensure checkpointing triggers recomputation under training and gradients propagate."""
121+ params = dict (
122+ spatial_dims = 2 , in_channels = 1 , out_channels = 1 , channels = (8 , 16 , 32 ), strides = (2 , 2 ), num_res_units = 1
123+ )
124+
125+ net = CheckpointUNet (** params ).to (device )
126+ net .train ()
127+
128+ x = torch .randn (2 , 1 , 32 , 32 , device = device , requires_grad = True )
129+ y = net (x )
130+ loss = y .mean ()
131+ loss .backward ()
132+
133+ # gradient flow check
134+ grad_norm = sum (p .grad .abs ().sum () for p in net .parameters () if p .grad is not None )
135+ self .assertGreater (grad_norm .item (), 0.0 )
136+
137+ # checkpointing should reduce activation memory use; we can't directly assert memory savings
138+ # but we can confirm no runtime errors and gradients propagate correctly
139+ self .assertIsNotNone (grad_norm )
205140
206141
207142if __name__ == "__main__" :
0 commit comments