
Commit 196b813

Merge branch 'dev' into 8620-modulenotfounderror-no-module-named-onnxscript-in-test-py3x-311-pipeline
Signed-off-by: Rafael Garcia-Dias <[email protected]>
2 parents 701692c + c968907 commit 196b813

12 files changed: +346 additions, -44 deletions


monai/apps/pathology/transforms/stain/array.py

Lines changed: 6 additions & 1 deletion
```diff
@@ -85,7 +85,12 @@ def _deconvolution_extract_stain(self, image: np.ndarray) -> np.ndarray:
         v_max = eigvecs[:, 1:3].dot(np.array([(np.cos(max_phi), np.sin(max_phi))], dtype=np.float32).T)

         # a heuristic to make the vector corresponding to hematoxylin first and the one corresponding to eosin second
-        if v_min[0] > v_max[0]:
+        # Hematoxylin: high blue, lower red (low R/B ratio)
+        # Eosin: high red, lower blue (high R/B ratio)
+        eps = np.finfo(np.float32).eps
+        v_min_rb_ratio = v_min[0] / (v_min[2] + eps)
+        v_max_rb_ratio = v_max[0] / (v_max[2] + eps)
+        if v_min_rb_ratio < v_max_rb_ratio:
             he = np.array((v_min[:, 0], v_max[:, 0]), dtype=np.float32).T
         else:
             he = np.array((v_max[:, 0], v_min[:, 0]), dtype=np.float32).T
```
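
Per the comment lines added above, hematoxylin is taken to be the blue-dominant vector (low red/blue ratio) and eosin the red-dominant one. The ordering rule can be exercised in isolation; the sketch below is a standalone illustration with made-up column vectors, not a call into the MONAI transform.

```python
import numpy as np


def order_he_vectors(v_min: np.ndarray, v_max: np.ndarray) -> np.ndarray:
    """Hypothetical helper mirroring the heuristic above: lower R/B ratio goes first (hematoxylin)."""
    eps = np.finfo(np.float32).eps
    v_min_rb_ratio = v_min[0] / (v_min[2] + eps)  # red / blue
    v_max_rb_ratio = v_max[0] / (v_max[2] + eps)
    if v_min_rb_ratio < v_max_rb_ratio:
        return np.array((v_min[:, 0], v_max[:, 0]), dtype=np.float32).T
    return np.array((v_max[:, 0], v_min[:, 0]), dtype=np.float32).T


# made-up (3, 1) column vectors: one blue-dominant, one red-dominant
v_blue = np.array([[0.19], [0.0], [0.98]], dtype=np.float32)  # low R/B -> treated as hematoxylin
v_red = np.array([[0.71], [0.0], [0.71]], dtype=np.float32)   # higher R/B -> treated as eosin
print(order_he_vectors(v_blue, v_red))  # hematoxylin column first, eosin column second
```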

monai/data/dataset.py

Lines changed: 22 additions & 3 deletions
```diff
@@ -230,6 +230,8 @@ def __init__(
         pickle_protocol: int = DEFAULT_PROTOCOL,
         hash_transform: Callable[..., bytes] | None = None,
         reset_ops_id: bool = True,
+        track_meta: bool = False,
+        weights_only: bool = True,
     ) -> None:
         """
         Args:
@@ -264,7 +266,17 @@ def __init__(
             When this is enabled, the traced transform instance IDs will be removed from the cached MetaTensors.
             This is useful for skipping the transform instance checks when inverting applied operations
             using the cached content and with re-created transform instances.
-
+            track_meta: whether to track the meta information, if `True`, will convert to `MetaTensor`.
+                default to `False`. Cannot be used with `weights_only=True`.
+            weights_only: keyword argument passed to `torch.load` when reading cached files.
+                default to `True`. When set to `True`, `torch.load` restricts loading to tensors and
+                other safe objects. Setting this to `False` is required for loading `MetaTensor`
+                objects saved with `track_meta=True`, however this creates the possibility of remote
+                code execution through `torch.load` so be aware of the security implications of doing so.
+
+        Raises:
+            ValueError: When both `track_meta=True` and `weights_only=True`, since this combination
+                prevents cached MetaTensors from being reloaded and causes perpetual cache regeneration.
         """
         super().__init__(data=data, transform=transform)
         self.cache_dir = Path(cache_dir) if cache_dir is not None else None
@@ -280,6 +292,13 @@ def __init__(
         if hash_transform is not None:
             self.set_transform_hash(hash_transform)
         self.reset_ops_id = reset_ops_id
+        if track_meta and weights_only:
+            raise ValueError(
+                "Invalid argument combination: `track_meta=True` cannot be used with `weights_only=True`. "
+                "To cache and reload MetaTensors, set `track_meta=True` and `weights_only=False`."
+            )
+        self.track_meta = track_meta
+        self.weights_only = weights_only

     def set_transform_hash(self, hash_xform_func: Callable[..., bytes]):
         """Get hashable transforms, and then hash them. Hashable transforms
@@ -377,7 +396,7 @@ def _cachecheck(self, item_transformed):

         if hashfile is not None and hashfile.is_file():  # cache hit
             try:
-                return torch.load(hashfile, weights_only=True)
+                return torch.load(hashfile, weights_only=self.weights_only)
             except PermissionError as e:
                 if sys.platform != "win32":
                     raise e
@@ -398,7 +417,7 @@ def _cachecheck(self, item_transformed):
         with tempfile.TemporaryDirectory() as tmpdirname:
             temp_hash_file = Path(tmpdirname) / hashfile.name
             torch.save(
-                obj=convert_to_tensor(_item_transformed, convert_numeric=False),
+                obj=convert_to_tensor(_item_transformed, convert_numeric=False, track_meta=self.track_meta),
                 f=temp_hash_file,
                 pickle_module=look_up_option(self.pickle_module, SUPPORTED_PICKLE_MOD),
                 pickle_protocol=self.pickle_protocol,
```
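
As a usage sketch of the new caching flags, assuming they are arguments of `PersistentDataset` (the class in `monai/data/dataset.py` that takes `cache_dir`, `hash_transform`, and `pickle_protocol`); the file path below is hypothetical:

```python
from monai.data import PersistentDataset
from monai.transforms import Compose, EnsureChannelFirstd, LoadImaged

transform = Compose([LoadImaged(keys="image"), EnsureChannelFirstd(keys="image")])
data = [{"image": "case_000.nii.gz"}]  # hypothetical input file

# Default behaviour: plain tensors are cached and reloaded with torch.load(weights_only=True).
ds_safe = PersistentDataset(data=data, transform=transform, cache_dir="cache")

# To cache MetaTensor objects (metadata preserved), both flags change together;
# track_meta=True with the default weights_only=True raises ValueError at construction.
ds_meta = PersistentDataset(
    data=data,
    transform=transform,
    cache_dir="cache_meta",
    track_meta=True,     # save MetaTensors into the cache
    weights_only=False,  # required to reload them; note the torch.load security caveat above
)
```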

monai/networks/blocks/activation_checkpointing.py (new file; path inferred from the import added to unet.py below)

Lines changed: 41 additions & 0 deletions

```diff
@@ -0,0 +1,41 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#     http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+from typing import cast
+
+import torch
+import torch.nn as nn
+from torch.utils.checkpoint import checkpoint
+
+
+class ActivationCheckpointWrapper(nn.Module):
+    """Wrapper applying activation checkpointing to a module during training.
+
+    Args:
+        module: The module to wrap with activation checkpointing.
+    """
+
+    def __init__(self, module: nn.Module) -> None:
+        super().__init__()
+        self.module = module
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """Forward pass with optional activation checkpointing.
+
+        Args:
+            x: Input tensor.
+
+        Returns:
+            Output tensor from the wrapped module.
+        """
+        return cast(torch.Tensor, checkpoint(self.module, x, use_reentrant=False))
```
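
A standalone sketch of how the wrapper behaves; the block and tensor shapes are arbitrary, and only the import path comes from this commit. Outputs match the unwrapped module, while the activations inside it are recomputed during the backward pass instead of being stored.

```python
import torch
import torch.nn as nn

from monai.networks.blocks.activation_checkpointing import ActivationCheckpointWrapper

block = nn.Sequential(nn.Conv2d(1, 8, 3, padding=1), nn.ReLU(), nn.Conv2d(8, 1, 3, padding=1))
wrapped = ActivationCheckpointWrapper(block)

x = torch.randn(2, 1, 32, 32, requires_grad=True)
y_plain = block(x)
y_ckpt = wrapped(x)
print(torch.allclose(y_plain, y_ckpt))  # expected True: checkpointing changes memory use, not results

y_ckpt.sum().backward()  # intermediate activations inside `block` are recomputed here
```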

monai/networks/nets/unet.py

Lines changed: 27 additions & 1 deletion
```diff
@@ -17,11 +17,12 @@
 import torch
 import torch.nn as nn

+from monai.networks.blocks.activation_checkpointing import ActivationCheckpointWrapper
 from monai.networks.blocks.convolutions import Convolution, ResidualUnit
 from monai.networks.layers.factories import Act, Norm
 from monai.networks.layers.simplelayers import SkipConnection

-__all__ = ["UNet", "Unet"]
+__all__ = ["UNet", "Unet", "CheckpointUNet"]


 class UNet(nn.Module):
@@ -298,4 +299,29 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         return x


+class CheckpointUNet(UNet):
+    """UNet variant that wraps internal connection blocks with activation checkpointing.
+
+    See `UNet` for constructor arguments. During training with gradients enabled,
+    intermediate activations inside encoder-decoder connections are recomputed in
+    the backward pass to reduce peak memory usage at the cost of extra compute.
+    """
+
+    def _get_connection_block(self, down_path: nn.Module, up_path: nn.Module, subblock: nn.Module) -> nn.Module:
+        """Returns connection block with activation checkpointing applied to all components.
+
+        Args:
+            down_path: encoding half of the layer (will be wrapped with checkpointing).
+            up_path: decoding half of the layer (will be wrapped with checkpointing).
+            subblock: block defining the next layer (will be wrapped with checkpointing).
+
+        Returns:
+            Connection block with all components wrapped for activation checkpointing.
+        """
+        subblock = ActivationCheckpointWrapper(subblock)
+        down_path = ActivationCheckpointWrapper(down_path)
+        up_path = ActivationCheckpointWrapper(up_path)
+        return super()._get_connection_block(down_path, up_path, subblock)
+
+
 Unet = UNet
```
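
Because `CheckpointUNet` only overrides `_get_connection_block`, it is constructed exactly like `UNet`; the hyperparameters in this sketch are arbitrary examples.

```python
import torch

from monai.networks.nets.unet import CheckpointUNet  # added in this commit

net = CheckpointUNet(
    spatial_dims=2,
    in_channels=1,
    out_channels=2,
    channels=(8, 16, 32),
    strides=(2, 2),
    num_res_units=2,
)

x = torch.randn(2, 1, 64, 64)
y = net(x)           # connection blocks run under activation checkpointing
y.mean().backward()  # wrapped activations are recomputed during this backward pass
print(y.shape)       # torch.Size([2, 2, 64, 64])
```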

monai/networks/schedulers/ddim.py

Lines changed: 8 additions & 12 deletions
```diff
@@ -117,18 +117,14 @@ def set_timesteps(self, num_inference_steps: int, device: str | torch.device | N
             )

         self.num_inference_steps = num_inference_steps
-        step_ratio = self.num_train_timesteps // self.num_inference_steps
-        if self.steps_offset >= step_ratio:
-            raise ValueError(
-                f"`steps_offset`: {self.steps_offset} cannot be greater than or equal to "
-                f"`num_train_timesteps // num_inference_steps : {step_ratio}` as this will cause timesteps to exceed"
-                f" the max train timestep."
-            )
-
-        # creates integer timesteps by multiplying by ratio
-        # casting to int to avoid issues when num_inference_step is power of 3
-        timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.int64)
-        self.timesteps = torch.from_numpy(timesteps).to(device)
+        if self.steps_offset < 0 or self.steps_offset >= self.num_train_timesteps:
+            raise ValueError(f"`steps_offset`: {self.steps_offset} must be in range [0, {self.num_train_timesteps}).")
+
+        self.timesteps = (
+            torch.linspace((self.num_train_timesteps - 1) - self.steps_offset, 0, num_inference_steps, device=device)
+            .round()
+            .long()
+        )
         self.timesteps += self.steps_offset

     def _get_variance(self, timestep: int, prev_timestep: int) -> torch.Tensor:
```
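
The new DDIM schedule is uniformly spaced between `num_train_timesteps - 1 - steps_offset` and 0, then shifted up by `steps_offset`. A small numeric sketch of just that expression (not a scheduler API call):

```python
import torch

num_train_timesteps = 1000
num_inference_steps = 4


def ddim_timesteps(steps_offset: int) -> torch.Tensor:
    # mirrors the new set_timesteps body: evenly spaced, then shifted by steps_offset
    t = torch.linspace((num_train_timesteps - 1) - steps_offset, 0, num_inference_steps).round().long()
    return t + steps_offset


print(ddim_timesteps(0))  # -> 999, 666, 333, 0
print(ddim_timesteps(1))  # highest timestep stays at 999, lowest becomes 1
```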

monai/networks/schedulers/ddpm.py

Lines changed: 3 additions & 6 deletions
```diff
@@ -31,7 +31,6 @@

 from __future__ import annotations

-import numpy as np
 import torch

 from monai.utils import StrEnum
@@ -122,11 +121,9 @@ def set_timesteps(self, num_inference_steps: int, device: str | torch.device | N
             )

         self.num_inference_steps = num_inference_steps
-        step_ratio = self.num_train_timesteps // self.num_inference_steps
-        # creates integer timesteps by multiplying by ratio
-        # casting to int to avoid issues when num_inference_step is power of 3
-        timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].astype(np.int64)
-        self.timesteps = torch.from_numpy(timesteps).to(device)
+        self.timesteps = (
+            torch.linspace(self.num_train_timesteps - 1, 0, self.num_inference_steps, device=device).round().long()
+        )

     def _get_mean(self, timestep: int, x_0: torch.Tensor, x_t: torch.Tensor) -> torch.Tensor:
         """
```

monai/transforms/croppad/functional.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -144,7 +144,7 @@ def crop_or_pad_nd(img: torch.Tensor, translation_mat, spatial_size: tuple[int,
         _mode = _convert_pt_pad_mode(mode)
         img = pad_nd(img, to_pad, mode=_mode, **kwargs)
     if do_crop:
-        img = img[to_crop]
+        img = img[tuple(to_crop)]
     return img

```
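
The `tuple(to_crop)` fix matters because indexing with a list of slices is not equivalent to indexing with a tuple of slices: the tuple selects per dimension (basic indexing), while the list is treated as advanced indexing and errors or warns depending on the torch version. A minimal illustration with made-up slices:

```python
import torch

img = torch.arange(24).reshape(2, 3, 4)
to_crop = [slice(None), slice(0, 2), slice(1, 3)]  # one slice per dimension

cropped = img[tuple(to_crop)]  # basic indexing: crop each dimension independently
print(cropped.shape)           # torch.Size([2, 2, 2])

# img[to_crop], with the list itself, is interpreted as advanced indexing and
# fails or behaves differently depending on the torch version, hence tuple(...).
```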

requirements-dev.txt

Lines changed: 1 addition & 1 deletion
```diff
@@ -51,7 +51,7 @@ h5py
 nni==2.10.1; platform_system == "Linux" and "arm" not in platform_machine and "aarch" not in platform_machine
 optuna
 git+https://github.com/Project-MONAI/MetricsReloaded@monai-support#egg=MetricsReloaded
-onnx>=1.13.0, <1.19.1
+onnx>=1.13.0
 onnxscript
 onnxruntime; python_version <= '3.10'
 typeguard<3  # https://github.com/microsoft/nni/issues/5457
```

tests/apps/pathology/transforms/test_pathology_he_stain.py

Lines changed: 8 additions & 8 deletions
```diff
@@ -48,7 +48,7 @@
 # input pixels not uniformly filled, leading to two different stains extracted
 EXTRACT_STAINS_TEST_CASE_5 = [
     np.array([[[100, 0, 0], [0, 0, 0]], [[0, 0, 0], [0, 0, 0]], [[0, 0, 0], [0, 0, 0]]]),
-    np.array([[0.70710677, 0.18696113], [0.0, 0.0], [0.70710677, 0.98236734]]),
+    np.array([[0.18696113, 0.70710677], [0.0, 0.0], [0.98236734, 0.70710677]]),
 ]

 # input pixels all transparent and below the beta absorbance threshold
@@ -68,7 +68,7 @@
 NORMALIZE_STAINS_TEST_CASE_4 = [
     {"target_he": np.full((3, 2), 1)},
     np.array([[[100, 0, 0], [0, 0, 0]], [[0, 0, 0], [0, 0, 0]], [[0, 0, 0], [0, 0, 0]]]),
-    np.array([[[87, 87, 87], [33, 33, 33]], [[33, 33, 33], [33, 33, 33]], [[33, 33, 33], [33, 33, 33]]]),
+    np.array([[[31, 31, 31], [85, 85, 85]], [[85, 85, 85], [85, 85, 85]], [[85, 85, 85], [85, 85, 85]]]),
 ]


@@ -135,7 +135,7 @@ def test_result_value(self, image, expected_data):
              [[0.18696113],[0],[0.98236734]] and
              [[0.70710677],[0],[0.70710677]] respectively
            - the resulting extracted stain should be
-             [[0.70710677,0.18696113],[0,0],[0.70710677,0.98236734]]
+             [[0.18696113,0.70710677],[0,0],[0.98236734,0.70710677]]
         """
         if image is None:
             with self.assertRaises(TypeError):
@@ -206,17 +206,17 @@ def test_result_value(self, arguments, image, expected_data):

         For test case 4:
           - For this non-uniformly filled image, the stain extracted should be
-            [[0.70710677,0.18696113],[0,0],[0.70710677,0.98236734]], as validated for the
+            [[0.18696113,0.70710677],[0,0],[0.98236734,0.70710677]], as validated for the
             ExtractHEStains class. Solving the linear least squares problem (since
             absorbance matrix = stain matrix * concentration matrix), we obtain the concentration
-            matrix that should be [[-0.3101, 7.7508, 7.7508, 7.7508, 7.7508, 7.7508],
-            [5.8022, 0, 0, 0, 0, 0]]
+            matrix that should be [[5.8022, 0, 0, 0, 0, 0],
+            [-0.3101, 7.7508, 7.7508, 7.7508, 7.7508, 7.7508]]
           - Normalizing the concentration matrix, taking the matrix product of the
             target stain matrix and the concentration matrix, using the inverse
             Beer-Lambert transform to obtain the RGB image from the absorbance
             image, and finally converting to uint8, we get that the stain normalized
-            image should be [[[87, 87, 87], [33, 33, 33]], [[33, 33, 33], [33, 33, 33]],
-            [[33, 33, 33], [33, 33, 33]]]
+            image should be [[[31, 31, 31], [85, 85, 85]], [[85, 85, 85], [85, 85, 85]],
+            [[85, 85, 85], [85, 85, 85]]]
         """
         if image is None:
             with self.assertRaises(TypeError):
```
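
The last step the test docstring describes, mapping absorbance back to a uint8 RGB image, is the inverse Beer-Lambert transform. A generic sketch of that step (the absorbance values are made up, and this is not the MONAI implementation):

```python
import numpy as np


def absorbance_to_rgb(absorbance: np.ndarray) -> np.ndarray:
    """Inverse Beer-Lambert transform: absorbance -> 8-bit RGB intensities."""
    rgb = 255.0 * np.exp(-absorbance)
    return np.clip(rgb, 0, 255).astype(np.uint8)


# made-up absorbance values; higher absorbance means a darker pixel
print(absorbance_to_rgb(np.array([0.0, 1.0, 2.0])))  # 255, 93, 34
```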

tests/apps/pathology/transforms/test_pathology_he_stain_dict.py

Lines changed: 8 additions & 8 deletions
```diff
@@ -42,7 +42,7 @@
 # input pixels not uniformly filled, leading to two different stains extracted
 EXTRACT_STAINS_TEST_CASE_5 = [
     np.array([[[100, 0, 0], [0, 0, 0]], [[0, 0, 0], [0, 0, 0]], [[0, 0, 0], [0, 0, 0]]]),
-    np.array([[0.70710677, 0.18696113], [0.0, 0.0], [0.70710677, 0.98236734]]),
+    np.array([[0.18696113, 0.70710677], [0.0, 0.0], [0.98236734, 0.70710677]]),
 ]

 # input pixels all transparent and below the beta absorbance threshold
@@ -62,7 +62,7 @@
 NORMALIZE_STAINS_TEST_CASE_4 = [
     {"target_he": np.full((3, 2), 1)},
     np.array([[[100, 0, 0], [0, 0, 0]], [[0, 0, 0], [0, 0, 0]], [[0, 0, 0], [0, 0, 0]]]),
-    np.array([[[87, 87, 87], [33, 33, 33]], [[33, 33, 33], [33, 33, 33]], [[33, 33, 33], [33, 33, 33]]]),
+    np.array([[[31, 31, 31], [85, 85, 85]], [[85, 85, 85], [85, 85, 85]], [[85, 85, 85], [85, 85, 85]]]),
 ]


@@ -129,7 +129,7 @@ def test_result_value(self, image, expected_data):
              [[0.18696113],[0],[0.98236734]] and
              [[0.70710677],[0],[0.70710677]] respectively
            - the resulting extracted stain should be
-             [[0.70710677,0.18696113],[0,0],[0.70710677,0.98236734]]
+             [[0.18696113,0.70710677],[0,0],[0.98236734,0.70710677]]
         """
         key = "image"
         if image is None:
@@ -200,17 +200,17 @@ def test_result_value(self, arguments, image, expected_data):

         For test case 4:
           - For this non-uniformly filled image, the stain extracted should be
-            [[0.70710677,0.18696113],[0,0],[0.70710677,0.98236734]], as validated for the
+            [[0.18696113,0.70710677],[0,0],[0.98236734,0.70710677]], as validated for the
             ExtractHEStains class. Solving the linear least squares problem (since
             absorbance matrix = stain matrix * concentration matrix), we obtain the concentration
-            matrix that should be [[-0.3101, 7.7508, 7.7508, 7.7508, 7.7508, 7.7508],
-            [5.8022, 0, 0, 0, 0, 0]]
+            matrix that should be [[5.8022, 0, 0, 0, 0, 0],
+            [-0.3101, 7.7508, 7.7508, 7.7508, 7.7508, 7.7508]]
           - Normalizing the concentration matrix, taking the matrix product of the
             target stain matrix and the concentration matrix, using the inverse
             Beer-Lambert transform to obtain the RGB image from the absorbance
             image, and finally converting to uint8, we get that the stain normalized
-            image should be [[[87, 87, 87], [33, 33, 33]], [[33, 33, 33], [33, 33, 33]],
-            [[33, 33, 33], [33, 33, 33]]]
+            image should be [[[31, 31, 31], [85, 85, 85]], [[85, 85, 85], [85, 85, 85]],
+            [[85, 85, 85], [85, 85, 85]]]
         """
         key = "image"
         if image is None:
```
