Skip to content

Commit b8dde47

Browse files
authored
Merge branch 'dev' into warp-doc-differentiability-warning
2 parents 2eefa92 + 23c271e commit b8dde47

File tree

6 files changed

+276
-18
lines changed

6 files changed

+276
-18
lines changed

monai/apps/pathology/transforms/stain/array.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,12 @@ def _deconvolution_extract_stain(self, image: np.ndarray) -> np.ndarray:
8585
v_max = eigvecs[:, 1:3].dot(np.array([(np.cos(max_phi), np.sin(max_phi))], dtype=np.float32).T)
8686

8787
# a heuristic to make the vector corresponding to hematoxylin first and the one corresponding to eosin second
88-
if v_min[0] > v_max[0]:
88+
# Hematoxylin: high blue, lower red (low R/B ratio)
89+
# Eosin: high red, lower blue (high R/B ratio)
90+
eps = np.finfo(np.float32).eps
91+
v_min_rb_ratio = v_min[0] / (v_min[2] + eps)
92+
v_max_rb_ratio = v_max[0] / (v_max[2] + eps)
93+
if v_min_rb_ratio < v_max_rb_ratio:
8994
he = np.array((v_min[:, 0], v_max[:, 0]), dtype=np.float32).T
9095
else:
9196
he = np.array((v_max[:, 0], v_min[:, 0]), dtype=np.float32).T
monai/networks/blocks/activation_checkpointing.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
# Copyright (c) MONAI Consortium
2+
# Licensed under the Apache License, Version 2.0 (the "License");
3+
# you may not use this file except in compliance with the License.
4+
# You may obtain a copy of the License at
5+
# http://www.apache.org/licenses/LICENSE-2.0
6+
# Unless required by applicable law or agreed to in writing, software
7+
# distributed under the License is distributed on an "AS IS" BASIS,
8+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9+
# See the License for the specific language governing permissions and
10+
# limitations under the License.
11+
12+
from __future__ import annotations
13+
14+
from typing import cast
15+
16+
import torch
17+
import torch.nn as nn
18+
from torch.utils.checkpoint import checkpoint
19+
20+
21+
class ActivationCheckpointWrapper(nn.Module):
    """Wrapper applying activation checkpointing to a module during training.

    Checkpointing is only used when the wrapper is in training mode and autograd
    is enabled. In every other situation (``eval()`` mode, or inside
    ``torch.no_grad()`` / ``torch.inference_mode()``) the wrapped module is called
    directly, so no checkpointing bookkeeping or recomputation takes place.
    The forward output is the same either way; only memory use and the amount of
    compute spent in the backward pass differ.

    Note:
        Because the wrapped module's forward is executed again during the
        backward pass, stateful side effects of that forward (for example
        updates of normalisation running statistics) may happen twice per step
        while checkpointing is active.

    Args:
        module: The module to wrap with activation checkpointing.
    """

    def __init__(self, module: nn.Module) -> None:
        super().__init__()
        self.module = module

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass with optional activation checkpointing.

        Args:
            x: Input tensor.

        Returns:
            Output tensor from the wrapped module.
        """
        # Checkpointing trades compute for memory in the backward pass, so it is
        # pointless when no autograd graph is built or when the module is in eval mode.
        if self.training and torch.is_grad_enabled():
            return cast(torch.Tensor, checkpoint(self.module, x, use_reentrant=False))
        return cast(torch.Tensor, self.module(x))

monai/networks/nets/unet.py

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,12 @@
1717
import torch
1818
import torch.nn as nn
1919

20+
from monai.networks.blocks.activation_checkpointing import ActivationCheckpointWrapper
2021
from monai.networks.blocks.convolutions import Convolution, ResidualUnit
2122
from monai.networks.layers.factories import Act, Norm
2223
from monai.networks.layers.simplelayers import SkipConnection
2324

24-
__all__ = ["UNet", "Unet"]
25+
__all__ = ["UNet", "Unet", "CheckpointUNet"]
2526

2627

2728
class UNet(nn.Module):
@@ -298,4 +299,29 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
298299
return x
299300

300301

302+
class CheckpointUNet(UNet):
    """UNet whose encoder/decoder connection blocks run under activation checkpointing.

    Constructor arguments are exactly those of `UNet`. Each component handed to
    `_get_connection_block` is wrapped in `ActivationCheckpointWrapper`, so that,
    whenever checkpointing is active, intermediate activations are recomputed in
    the backward pass instead of being stored. This lowers peak memory usage at
    the cost of extra compute.
    """

    def _get_connection_block(self, down_path: nn.Module, up_path: nn.Module, subblock: nn.Module) -> nn.Module:
        """Build a connection block from checkpoint-wrapped components.

        Args:
            down_path: encoding half of the layer; wrapped before use.
            up_path: decoding half of the layer; wrapped before use.
            subblock: block defining the next layer; wrapped before use.

        Returns:
            The connection block produced by the parent class from the three
            wrapped components.
        """
        ckpt_subblock = ActivationCheckpointWrapper(subblock)
        ckpt_down = ActivationCheckpointWrapper(down_path)
        ckpt_up = ActivationCheckpointWrapper(up_path)
        return super()._get_connection_block(ckpt_down, ckpt_up, ckpt_subblock)
325+
326+
301327
Unet = UNet

tests/apps/pathology/transforms/test_pathology_he_stain.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@
4848
# input pixels not uniformly filled, leading to two different stains extracted
4949
EXTRACT_STAINS_TEST_CASE_5 = [
5050
np.array([[[100, 0, 0], [0, 0, 0]], [[0, 0, 0], [0, 0, 0]], [[0, 0, 0], [0, 0, 0]]]),
51-
np.array([[0.70710677, 0.18696113], [0.0, 0.0], [0.70710677, 0.98236734]]),
51+
np.array([[0.18696113, 0.70710677], [0.0, 0.0], [0.98236734, 0.70710677]]),
5252
]
5353

5454
# input pixels all transparent and below the beta absorbance threshold
@@ -68,7 +68,7 @@
6868
NORMALIZE_STAINS_TEST_CASE_4 = [
6969
{"target_he": np.full((3, 2), 1)},
7070
np.array([[[100, 0, 0], [0, 0, 0]], [[0, 0, 0], [0, 0, 0]], [[0, 0, 0], [0, 0, 0]]]),
71-
np.array([[[87, 87, 87], [33, 33, 33]], [[33, 33, 33], [33, 33, 33]], [[33, 33, 33], [33, 33, 33]]]),
71+
np.array([[[31, 31, 31], [85, 85, 85]], [[85, 85, 85], [85, 85, 85]], [[85, 85, 85], [85, 85, 85]]]),
7272
]
7373

7474

@@ -135,7 +135,7 @@ def test_result_value(self, image, expected_data):
135135
[[0.18696113],[0],[0.98236734]] and
136136
[[0.70710677],[0],[0.70710677]] respectively
137137
- the resulting extracted stain should be
138-
[[0.70710677,0.18696113],[0,0],[0.70710677,0.98236734]]
138+
[[0.18696113,0.70710677],[0,0],[0.98236734,0.70710677]]
139139
"""
140140
if image is None:
141141
with self.assertRaises(TypeError):
@@ -206,17 +206,17 @@ def test_result_value(self, arguments, image, expected_data):
206206
207207
For test case 4:
208208
- For this non-uniformly filled image, the stain extracted should be
209-
[[0.70710677,0.18696113],[0,0],[0.70710677,0.98236734]], as validated for the
209+
[[0.18696113,0.70710677],[0,0],[0.98236734,0.70710677]], as validated for the
210210
ExtractHEStains class. Solving the linear least squares problem (since
211211
absorbance matrix = stain matrix * concentration matrix), we obtain the concentration
212-
matrix that should be [[-0.3101, 7.7508, 7.7508, 7.7508, 7.7508, 7.7508],
213-
[5.8022, 0, 0, 0, 0, 0]]
212+
matrix that should be [[5.8022, 0, 0, 0, 0, 0],
213+
[-0.3101, 7.7508, 7.7508, 7.7508, 7.7508, 7.7508]]
214214
- Normalizing the concentration matrix, taking the matrix product of the
215215
target stain matrix and the concentration matrix, using the inverse
216216
Beer-Lambert transform to obtain the RGB image from the absorbance
217217
image, and finally converting to uint8, we get that the stain normalized
218-
image should be [[[87, 87, 87], [33, 33, 33]], [[33, 33, 33], [33, 33, 33]],
219-
[[33, 33, 33], [33, 33, 33]]]
218+
image should be [[[31, 31, 31], [85, 85, 85]], [[85, 85, 85], [85, 85, 85]],
219+
[[85, 85, 85], [85, 85, 85]]]
220220
"""
221221
if image is None:
222222
with self.assertRaises(TypeError):

tests/apps/pathology/transforms/test_pathology_he_stain_dict.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@
4242
# input pixels not uniformly filled, leading to two different stains extracted
4343
EXTRACT_STAINS_TEST_CASE_5 = [
4444
np.array([[[100, 0, 0], [0, 0, 0]], [[0, 0, 0], [0, 0, 0]], [[0, 0, 0], [0, 0, 0]]]),
45-
np.array([[0.70710677, 0.18696113], [0.0, 0.0], [0.70710677, 0.98236734]]),
45+
np.array([[0.18696113, 0.70710677], [0.0, 0.0], [0.98236734, 0.70710677]]),
4646
]
4747

4848
# input pixels all transparent and below the beta absorbance threshold
@@ -62,7 +62,7 @@
6262
NORMALIZE_STAINS_TEST_CASE_4 = [
6363
{"target_he": np.full((3, 2), 1)},
6464
np.array([[[100, 0, 0], [0, 0, 0]], [[0, 0, 0], [0, 0, 0]], [[0, 0, 0], [0, 0, 0]]]),
65-
np.array([[[87, 87, 87], [33, 33, 33]], [[33, 33, 33], [33, 33, 33]], [[33, 33, 33], [33, 33, 33]]]),
65+
np.array([[[31, 31, 31], [85, 85, 85]], [[85, 85, 85], [85, 85, 85]], [[85, 85, 85], [85, 85, 85]]]),
6666
]
6767

6868

@@ -129,7 +129,7 @@ def test_result_value(self, image, expected_data):
129129
[[0.18696113],[0],[0.98236734]] and
130130
[[0.70710677],[0],[0.70710677]] respectively
131131
- the resulting extracted stain should be
132-
[[0.70710677,0.18696113],[0,0],[0.70710677,0.98236734]]
132+
[[0.18696113,0.70710677],[0,0],[0.98236734,0.70710677]]
133133
"""
134134
key = "image"
135135
if image is None:
@@ -200,17 +200,17 @@ def test_result_value(self, arguments, image, expected_data):
200200
201201
For test case 4:
202202
- For this non-uniformly filled image, the stain extracted should be
203-
[[0.70710677,0.18696113],[0,0],[0.70710677,0.98236734]], as validated for the
203+
[[0.18696113,0.70710677],[0,0],[0.98236734,0.70710677]], as validated for the
204204
ExtractHEStains class. Solving the linear least squares problem (since
205205
absorbance matrix = stain matrix * concentration matrix), we obtain the concentration
206-
matrix that should be [[-0.3101, 7.7508, 7.7508, 7.7508, 7.7508, 7.7508],
207-
[5.8022, 0, 0, 0, 0, 0]]
206+
matrix that should be [[5.8022, 0, 0, 0, 0, 0],
207+
[-0.3101, 7.7508, 7.7508, 7.7508, 7.7508, 7.7508]]
208208
- Normalizing the concentration matrix, taking the matrix product of the
209209
target stain matrix and the concentration matrix, using the inverse
210210
Beer-Lambert transform to obtain the RGB image from the absorbance
211211
image, and finally converting to uint8, we get that the stain normalized
212-
image should be [[[87, 87, 87], [33, 33, 33]], [[33, 33, 33], [33, 33, 33]],
213-
[[33, 33, 33], [33, 33, 33]]]
212+
image should be [[[31, 31, 31], [85, 85, 85]], [[85, 85, 85], [85, 85, 85]],
213+
[[85, 85, 85], [85, 85, 85]]]
214214
"""
215215
key = "image"
216216
if image is None:
Lines changed: 186 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,186 @@
1+
# Copyright (c) MONAI Consortium
2+
# Licensed under the Apache License, Version 2.0 (the "License");
3+
# you may not use this file except in compliance with the License.
4+
# You may obtain a copy of the License at
5+
# http://www.apache.org/licenses/LICENSE-2.0
6+
# Unless required by applicable law or agreed to in writing, software
7+
# distributed under the License is distributed on an "AS IS" BASIS,
8+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9+
# See the License for the specific language governing permissions and
10+
# limitations under the License.
11+
12+
from __future__ import annotations
13+
14+
import unittest
15+
16+
import torch
17+
from parameterized import parameterized
18+
19+
from monai.networks import eval_mode
20+
from monai.networks.layers import Act, Norm
21+
from monai.networks.nets.unet import CheckpointUNet, UNet
22+
23+
device = "cuda" if torch.cuda.is_available() else "cpu"


def _make_case(spatial_dims, in_channels, spatial_size, num_res_units=1, **extra):
    """Build one ``[constructor kwargs, input shape, expected output shape]`` entry.

    Every case uses batch size 16, 3 output channels, channels ``(16, 32, 64)``
    and strides ``(2, 2)``; ``extra`` carries any additional constructor arguments.
    """
    kwargs = {
        "spatial_dims": spatial_dims,
        "in_channels": in_channels,
        "out_channels": 3,
        "channels": (16, 32, 64),
        "strides": (2, 2),
        "num_res_units": num_res_units,
        **extra,
    }
    return [kwargs, (16, in_channels, *spatial_size), (16, 3, *spatial_size)]


# single channel 2D, batch 16, no residual
TEST_CASE_0 = _make_case(2, 1, (32, 32), num_res_units=0)

# single channel 2D, batch 16
TEST_CASE_1 = _make_case(2, 1, (32, 32))

# single channel 3D, batch 16
TEST_CASE_2 = _make_case(3, 1, (32, 24, 48))

# 4-channel 3D, batch 16
TEST_CASE_3 = _make_case(3, 4, (32, 64, 48))

# 4-channel 3D, batch 16, batch normalization
TEST_CASE_4 = _make_case(3, 4, (32, 64, 48), norm=Norm.BATCH)

# 4-channel 3D, batch 16, LeakyReLU activation
TEST_CASE_5 = _make_case(
    3, 4, (32, 64, 48), act=(Act.LEAKYRELU, {"negative_slope": 0.2}), adn_ordering="NA"
)

# 4-channel 3D, batch 16, LeakyReLU activation explicit
TEST_CASE_6 = _make_case(3, 4, (32, 64, 48), act=(torch.nn.LeakyReLU, {"negative_slope": 0.2}))

CASES = [TEST_CASE_0, TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5, TEST_CASE_6]
121+
122+
123+
class TestCheckpointUNet(unittest.TestCase):
    """Tests for `CheckpointUNet`: output shapes, forward parity with `UNet`, and gradient flow."""

    @parameterized.expand(CASES)
    def test_shape(self, input_param, input_shape, expected_shape):
        """Validate CheckpointUNet output shapes across configurations.

        Args:
            input_param: Dictionary of UNet constructor arguments.
            input_shape: Tuple specifying input tensor dimensions.
            expected_shape: Tuple specifying expected output tensor dimensions.
        """
        net = CheckpointUNet(**input_param).to(device)
        with eval_mode(net):
            # call the module itself (not ``.forward``) so the regular nn.Module call path is exercised
            result = net(torch.randn(input_shape).to(device))
        self.assertEqual(result.shape, expected_shape)

    def test_checkpointing_equivalence_eval(self):
        """Confirm CheckpointUNet and UNet give the same forward output in eval mode."""
        params = dict(
            spatial_dims=2, in_channels=1, out_channels=2, channels=(8, 16, 32), strides=(2, 2), num_res_units=1
        )

        x = torch.randn(2, 1, 32, 32, device=device)

        # Re-seeding before each construction is what gives both networks the same initial
        # weights; the comparison below is only meaningful under that assumption.
        torch.manual_seed(42)
        net_plain = UNet(**params).to(device)

        torch.manual_seed(42)
        net_ckpt = CheckpointUNet(**params).to(device)

        # Wrapping blocks for checkpointing must never change the forward result.
        with eval_mode(net_ckpt), eval_mode(net_plain):
            y_ckpt = net_ckpt(x)
            y_plain = net_plain(x)

        # Check shape equality
        self.assertEqual(y_ckpt.shape, y_plain.shape)

        # Check numerical equivalence
        self.assertTrue(
            torch.allclose(y_ckpt, y_plain, atol=1e-6, rtol=1e-5),
            f"Eval-mode outputs differ: max abs diff={torch.max(torch.abs(y_ckpt - y_plain)).item():.2e}",
        )

    def test_checkpointing_activates_training(self):
        """Verify gradients reach the parameters and the input when training with checkpointing."""
        params = dict(
            spatial_dims=2, in_channels=1, out_channels=1, channels=(8, 16, 32), strides=(2, 2), num_res_units=1
        )

        net = CheckpointUNet(**params).to(device)
        net.train()

        x = torch.randn(2, 1, 32, 32, device=device, requires_grad=True)
        y = net(x)
        loss = y.mean()
        loss.backward()

        # gradient flow check
        grads = [p.grad for p in net.parameters() if p.grad is not None]
        # Guard explicitly: summing an empty sequence would give a plain int and hide
        # the real failure behind an AttributeError on ``.item()``.
        self.assertTrue(grads, "no parameter received a gradient")
        self.assertIsNotNone(x.grad)

        grad_norm = torch.stack([g.abs().sum() for g in grads]).sum()
        self.assertGreater(grad_norm.item(), 0.0)
183+
184+
185+
if __name__ == "__main__":
186+
unittest.main()

0 commit comments

Comments
 (0)