124124class TestCheckpointUNet (unittest .TestCase ):
125125 @parameterized .expand (CASES )
126126 def test_shape (self , input_param , input_shape , expected_shape ):
127+ """Validate CheckpointUNet output shapes across configurations.
128+
129+ Args:
130+ input_param: Mapping of constructor kwargs for the network under test.
131+ input_shape: Shape tuple for the synthetic input tensor.
132+ expected_shape: Expected output tensor shape.
133+ """
127134 net = CheckpointUNet (** input_param ).to (device )
128135 with eval_mode (net ):
129136 result = net .forward (torch .randn (input_shape ).to (device ))
130137 self .assertEqual (result .shape , expected_shape )
131138
132139 def test_script (self ):
133- """
134- TorchScript doesn't support activation-checkpointing (torch.utils.checkpoint) calls inside the module.
135- To keep the test suite validating TorchScript compatibility, script the plain UNet (which is scriptable),
136- rather than the CheckpointUNet wrapper that uses checkpointing internals.
137- """
140+ """Script the baseline UNet to maintain TorchScript coverage."""
138141 net = UNet (
139142 spatial_dims = 2 , in_channels = 1 , out_channels = 3 , channels = (16 , 32 , 64 ), strides = (2 , 2 ), num_res_units = 0
140143 )
141144 test_data = torch .randn (16 , 1 , 32 , 32 )
142145 test_script_save (net , test_data )
143146
144147 def test_checkpointing_equivalence_eval (self ):
145- """Ensure that CheckpointUNet matches standard UNet in eval mode ( checkpointing inactive) ."""
148+ """Confirm eval parity when checkpointing is inactive."""
146149 params = dict (
147150 spatial_dims = 2 , in_channels = 1 , out_channels = 2 , channels = (8 , 16 , 32 ), strides = (2 , 2 ), num_res_units = 1
148151 )
@@ -168,7 +171,7 @@ def test_checkpointing_equivalence_eval(self):
168171 self .assertLess (diff , 1e-3 , f"Eval-mode outputs differ more than expected (mean abs diff={ diff :.6f} )" )
169172
170173 def test_checkpointing_activates_training (self ):
171- """Ensure checkpointing triggers recomputation under training and gradients propagate ."""
174+ """Verify checkpointing recomputes activations during training."""
172175 params = dict (
173176 spatial_dims = 2 , in_channels = 1 , out_channels = 1 , channels = (8 , 16 , 32 ), strides = (2 , 2 ), num_res_units = 1
174177 )
0 commit comments