Skip to content

Commit a068c0e

Browse files
author
Fabio Ferreira
committed
chore: add test docstrings
Signed-off-by: Fabio Ferreira <[email protected]>
1 parent 41f000f commit a068c0e

File tree

1 file changed

+10
-7
lines changed

1 file changed

+10
-7
lines changed

tests/networks/nets/test_checkpointunet.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -124,25 +124,28 @@
124124
class TestCheckpointUNet(unittest.TestCase):
125125
@parameterized.expand(CASES)
126126
def test_shape(self, input_param, input_shape, expected_shape):
127+
"""Validate CheckpointUNet output shapes across configurations.
128+
129+
Args:
130+
input_param: Mapping of constructor kwargs for the network under test.
131+
input_shape: Shape tuple for the synthetic input tensor.
132+
expected_shape: Expected output tensor shape.
133+
"""
127134
net = CheckpointUNet(**input_param).to(device)
128135
with eval_mode(net):
129136
result = net.forward(torch.randn(input_shape).to(device))
130137
self.assertEqual(result.shape, expected_shape)
131138

132139
def test_script(self):
133-
"""
134-
TorchScript doesn't support activation-checkpointing (torch.utils.checkpoint) calls inside the module.
135-
To keep the test suite validating TorchScript compatibility, script the plain UNet (which is scriptable),
136-
rather than the CheckpointUNet wrapper that uses checkpointing internals.
137-
"""
140+
"""Script the baseline UNet to maintain TorchScript coverage."""
138141
net = UNet(
139142
spatial_dims=2, in_channels=1, out_channels=3, channels=(16, 32, 64), strides=(2, 2), num_res_units=0
140143
)
141144
test_data = torch.randn(16, 1, 32, 32)
142145
test_script_save(net, test_data)
143146

144147
def test_checkpointing_equivalence_eval(self):
145-
"""Ensure that CheckpointUNet matches standard UNet in eval mode (checkpointing inactive)."""
148+
"""Confirm eval parity when checkpointing is inactive."""
146149
params = dict(
147150
spatial_dims=2, in_channels=1, out_channels=2, channels=(8, 16, 32), strides=(2, 2), num_res_units=1
148151
)
@@ -168,7 +171,7 @@ def test_checkpointing_equivalence_eval(self):
168171
self.assertLess(diff, 1e-3, f"Eval-mode outputs differ more than expected (mean abs diff={diff:.6f})")
169172

170173
def test_checkpointing_activates_training(self):
171-
"""Ensure checkpointing triggers recomputation under training and gradients propagate."""
174+
"""Verify checkpointing recomputes activations during training."""
172175
params = dict(
173176
spatial_dims=2, in_channels=1, out_channels=1, channels=(8, 16, 32), strides=(2, 2), num_res_units=1
174177
)

0 commit comments

Comments
 (0)