Skip to content

Commit c2d9be8

Browse files
authored
Merge branch 'dev' into dev
2 parents 341538d + 23c271e commit c2d9be8

File tree

8 files changed

+287
-36
lines changed

8 files changed

+287
-36
lines changed

monai/apps/pathology/transforms/stain/array.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,12 @@ def _deconvolution_extract_stain(self, image: np.ndarray) -> np.ndarray:
8585
v_max = eigvecs[:, 1:3].dot(np.array([(np.cos(max_phi), np.sin(max_phi))], dtype=np.float32).T)
8686

8787
# a heuristic to make the vector corresponding to hematoxylin first and the one corresponding to eosin second
88-
if v_min[0] > v_max[0]:
88+
# Hematoxylin: high blue, lower red (low R/B ratio)
89+
# Eosin: high red, lower blue (high R/B ratio)
90+
eps = np.finfo(np.float32).eps
91+
v_min_rb_ratio = v_min[0] / (v_min[2] + eps)
92+
v_max_rb_ratio = v_max[0] / (v_max[2] + eps)
93+
if v_min_rb_ratio < v_max_rb_ratio:
8994
he = np.array((v_min[:, 0], v_max[:, 0]), dtype=np.float32).T
9095
else:
9196
he = np.array((v_max[:, 0], v_min[:, 0]), dtype=np.float32).T
Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
# Copyright (c) MONAI Consortium
2+
# Licensed under the Apache License, Version 2.0 (the "License");
3+
# you may not use this file except in compliance with the License.
4+
# You may obtain a copy of the License at
5+
# http://www.apache.org/licenses/LICENSE-2.0
6+
# Unless required by applicable law or agreed to in writing, software
7+
# distributed under the License is distributed on an "AS IS" BASIS,
8+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9+
# See the License for the specific language governing permissions and
10+
# limitations under the License.
11+
12+
from __future__ import annotations
13+
14+
from typing import cast
15+
16+
import torch
17+
import torch.nn as nn
18+
from torch.utils.checkpoint import checkpoint
19+
20+
21+
class ActivationCheckpointWrapper(nn.Module):
22+
"""Wrapper applying activation checkpointing to a module during training.
23+
24+
Args:
25+
module: The module to wrap with activation checkpointing.
26+
"""
27+
28+
def __init__(self, module: nn.Module) -> None:
29+
super().__init__()
30+
self.module = module
31+
32+
def forward(self, x: torch.Tensor) -> torch.Tensor:
33+
"""Forward pass with optional activation checkpointing.
34+
35+
Args:
36+
x: Input tensor.
37+
38+
Returns:
39+
Output tensor from the wrapped module.
40+
"""
41+
return cast(torch.Tensor, checkpoint(self.module, x, use_reentrant=False))

monai/networks/nets/unet.py

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,12 @@
1717
import torch
1818
import torch.nn as nn
1919

20+
from monai.networks.blocks.activation_checkpointing import ActivationCheckpointWrapper
2021
from monai.networks.blocks.convolutions import Convolution, ResidualUnit
2122
from monai.networks.layers.factories import Act, Norm
2223
from monai.networks.layers.simplelayers import SkipConnection
2324

24-
__all__ = ["UNet", "Unet"]
25+
__all__ = ["UNet", "Unet", "CheckpointUNet"]
2526

2627

2728
class UNet(nn.Module):
@@ -298,4 +299,29 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
298299
return x
299300

300301

302+
class CheckpointUNet(UNet):
303+
"""UNet variant that wraps internal connection blocks with activation checkpointing.
304+
305+
See `UNet` for constructor arguments. During training with gradients enabled,
306+
intermediate activations inside encoder-decoder connections are recomputed in
307+
the backward pass to reduce peak memory usage at the cost of extra compute.
308+
"""
309+
310+
def _get_connection_block(self, down_path: nn.Module, up_path: nn.Module, subblock: nn.Module) -> nn.Module:
311+
"""Returns connection block with activation checkpointing applied to all components.
312+
313+
Args:
314+
down_path: encoding half of the layer (will be wrapped with checkpointing).
315+
up_path: decoding half of the layer (will be wrapped with checkpointing).
316+
subblock: block defining the next layer (will be wrapped with checkpointing).
317+
318+
Returns:
319+
Connection block with all components wrapped for activation checkpointing.
320+
"""
321+
subblock = ActivationCheckpointWrapper(subblock)
322+
down_path = ActivationCheckpointWrapper(down_path)
323+
up_path = ActivationCheckpointWrapper(up_path)
324+
return super()._get_connection_block(down_path, up_path, subblock)
325+
326+
301327
Unet = UNet

monai/networks/schedulers/ddim.py

Lines changed: 8 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -117,18 +117,14 @@ def set_timesteps(self, num_inference_steps: int, device: str | torch.device | N
117117
)
118118

119119
self.num_inference_steps = num_inference_steps
120-
step_ratio = self.num_train_timesteps // self.num_inference_steps
121-
if self.steps_offset >= step_ratio:
122-
raise ValueError(
123-
f"`steps_offset`: {self.steps_offset} cannot be greater than or equal to "
124-
f"`num_train_timesteps // num_inference_steps : {step_ratio}` as this will cause timesteps to exceed"
125-
f" the max train timestep."
126-
)
127-
128-
# creates integer timesteps by multiplying by ratio
129-
# casting to int to avoid issues when num_inference_step is power of 3
130-
timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.int64)
131-
self.timesteps = torch.from_numpy(timesteps).to(device)
120+
if self.steps_offset < 0 or self.steps_offset >= self.num_train_timesteps:
121+
raise ValueError(f"`steps_offset`: {self.steps_offset} must be in range [0, {self.num_train_timesteps}).")
122+
123+
self.timesteps = (
124+
torch.linspace((self.num_train_timesteps - 1) - self.steps_offset, 0, num_inference_steps, device=device)
125+
.round()
126+
.long()
127+
)
132128
self.timesteps += self.steps_offset
133129

134130
def _get_variance(self, timestep: int, prev_timestep: int) -> torch.Tensor:

monai/networks/schedulers/ddpm.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,6 @@
3131

3232
from __future__ import annotations
3333

34-
import numpy as np
3534
import torch
3635

3736
from monai.utils import StrEnum
@@ -122,11 +121,9 @@ def set_timesteps(self, num_inference_steps: int, device: str | torch.device | N
122121
)
123122

124123
self.num_inference_steps = num_inference_steps
125-
step_ratio = self.num_train_timesteps // self.num_inference_steps
126-
# creates integer timesteps by multiplying by ratio
127-
# casting to int to avoid issues when num_inference_step is power of 3
128-
timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].astype(np.int64)
129-
self.timesteps = torch.from_numpy(timesteps).to(device)
124+
self.timesteps = (
125+
torch.linspace(self.num_train_timesteps - 1, 0, self.num_inference_steps, device=device).round().long()
126+
)
130127

131128
def _get_mean(self, timestep: int, x_0: torch.Tensor, x_t: torch.Tensor) -> torch.Tensor:
132129
"""

tests/apps/pathology/transforms/test_pathology_he_stain.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@
4848
# input pixels not uniformly filled, leading to two different stains extracted
4949
EXTRACT_STAINS_TEST_CASE_5 = [
5050
np.array([[[100, 0, 0], [0, 0, 0]], [[0, 0, 0], [0, 0, 0]], [[0, 0, 0], [0, 0, 0]]]),
51-
np.array([[0.70710677, 0.18696113], [0.0, 0.0], [0.70710677, 0.98236734]]),
51+
np.array([[0.18696113, 0.70710677], [0.0, 0.0], [0.98236734, 0.70710677]]),
5252
]
5353

5454
# input pixels all transparent and below the beta absorbance threshold
@@ -68,7 +68,7 @@
6868
NORMALIZE_STAINS_TEST_CASE_4 = [
6969
{"target_he": np.full((3, 2), 1)},
7070
np.array([[[100, 0, 0], [0, 0, 0]], [[0, 0, 0], [0, 0, 0]], [[0, 0, 0], [0, 0, 0]]]),
71-
np.array([[[87, 87, 87], [33, 33, 33]], [[33, 33, 33], [33, 33, 33]], [[33, 33, 33], [33, 33, 33]]]),
71+
np.array([[[31, 31, 31], [85, 85, 85]], [[85, 85, 85], [85, 85, 85]], [[85, 85, 85], [85, 85, 85]]]),
7272
]
7373

7474

@@ -135,7 +135,7 @@ def test_result_value(self, image, expected_data):
135135
[[0.18696113],[0],[0.98236734]] and
136136
[[0.70710677],[0],[0.70710677]] respectively
137137
- the resulting extracted stain should be
138-
[[0.70710677,0.18696113],[0,0],[0.70710677,0.98236734]]
138+
[[0.18696113,0.70710677],[0,0],[0.98236734,0.70710677]]
139139
"""
140140
if image is None:
141141
with self.assertRaises(TypeError):
@@ -206,17 +206,17 @@ def test_result_value(self, arguments, image, expected_data):
206206
207207
For test case 4:
208208
- For this non-uniformly filled image, the stain extracted should be
209-
[[0.70710677,0.18696113],[0,0],[0.70710677,0.98236734]], as validated for the
209+
[[0.18696113,0.70710677],[0,0],[0.98236734,0.70710677]], as validated for the
210210
ExtractHEStains class. Solving the linear least squares problem (since
211211
absorbance matrix = stain matrix * concentration matrix), we obtain the concentration
212-
matrix that should be [[-0.3101, 7.7508, 7.7508, 7.7508, 7.7508, 7.7508],
213-
[5.8022, 0, 0, 0, 0, 0]]
212+
matrix that should be [[5.8022, 0, 0, 0, 0, 0],
213+
[-0.3101, 7.7508, 7.7508, 7.7508, 7.7508, 7.7508]]
214214
- Normalizing the concentration matrix, taking the matrix product of the
215215
target stain matrix and the concentration matrix, using the inverse
216216
Beer-Lambert transform to obtain the RGB image from the absorbance
217217
image, and finally converting to uint8, we get that the stain normalized
218-
image should be [[[87, 87, 87], [33, 33, 33]], [[33, 33, 33], [33, 33, 33]],
219-
[[33, 33, 33], [33, 33, 33]]]
218+
image should be [[[31, 31, 31], [85, 85, 85]], [[85, 85, 85], [85, 85, 85]],
219+
[[85, 85, 85], [85, 85, 85]]]
220220
"""
221221
if image is None:
222222
with self.assertRaises(TypeError):

tests/apps/pathology/transforms/test_pathology_he_stain_dict.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@
4242
# input pixels not uniformly filled, leading to two different stains extracted
4343
EXTRACT_STAINS_TEST_CASE_5 = [
4444
np.array([[[100, 0, 0], [0, 0, 0]], [[0, 0, 0], [0, 0, 0]], [[0, 0, 0], [0, 0, 0]]]),
45-
np.array([[0.70710677, 0.18696113], [0.0, 0.0], [0.70710677, 0.98236734]]),
45+
np.array([[0.18696113, 0.70710677], [0.0, 0.0], [0.98236734, 0.70710677]]),
4646
]
4747

4848
# input pixels all transparent and below the beta absorbance threshold
@@ -62,7 +62,7 @@
6262
NORMALIZE_STAINS_TEST_CASE_4 = [
6363
{"target_he": np.full((3, 2), 1)},
6464
np.array([[[100, 0, 0], [0, 0, 0]], [[0, 0, 0], [0, 0, 0]], [[0, 0, 0], [0, 0, 0]]]),
65-
np.array([[[87, 87, 87], [33, 33, 33]], [[33, 33, 33], [33, 33, 33]], [[33, 33, 33], [33, 33, 33]]]),
65+
np.array([[[31, 31, 31], [85, 85, 85]], [[85, 85, 85], [85, 85, 85]], [[85, 85, 85], [85, 85, 85]]]),
6666
]
6767

6868

@@ -129,7 +129,7 @@ def test_result_value(self, image, expected_data):
129129
[[0.18696113],[0],[0.98236734]] and
130130
[[0.70710677],[0],[0.70710677]] respectively
131131
- the resulting extracted stain should be
132-
[[0.70710677,0.18696113],[0,0],[0.70710677,0.98236734]]
132+
[[0.18696113,0.70710677],[0,0],[0.98236734,0.70710677]]
133133
"""
134134
key = "image"
135135
if image is None:
@@ -200,17 +200,17 @@ def test_result_value(self, arguments, image, expected_data):
200200
201201
For test case 4:
202202
- For this non-uniformly filled image, the stain extracted should be
203-
[[0.70710677,0.18696113],[0,0],[0.70710677,0.98236734]], as validated for the
203+
[[0.18696113,0.70710677],[0,0],[0.98236734,0.70710677]], as validated for the
204204
ExtractHEStains class. Solving the linear least squares problem (since
205205
absorbance matrix = stain matrix * concentration matrix), we obtain the concentration
206-
matrix that should be [[-0.3101, 7.7508, 7.7508, 7.7508, 7.7508, 7.7508],
207-
[5.8022, 0, 0, 0, 0, 0]]
206+
matrix that should be [[5.8022, 0, 0, 0, 0, 0],
207+
[-0.3101, 7.7508, 7.7508, 7.7508, 7.7508, 7.7508]]
208208
- Normalizing the concentration matrix, taking the matrix product of the
209209
target stain matrix and the concentration matrix, using the inverse
210210
Beer-Lambert transform to obtain the RGB image from the absorbance
211211
image, and finally converting to uint8, we get that the stain normalized
212-
image should be [[[87, 87, 87], [33, 33, 33]], [[33, 33, 33], [33, 33, 33]],
213-
[[33, 33, 33], [33, 33, 33]]]
212+
image should be [[[31, 31, 31], [85, 85, 85]], [[85, 85, 85], [85, 85, 85]],
213+
[[85, 85, 85], [85, 85, 85]]]
214214
"""
215215
key = "image"
216216
if image is None:

0 commit comments

Comments
 (0)