
Commit 936c40e

Merge branch 'dev' into feature/add-non-central-chi-noise
2 parents: a7f0862 + c968907

File tree

25 files changed: +564 -65 lines


.github/workflows/codeql-analysis.yml

Lines changed: 2 additions & 2 deletions
@@ -42,7 +42,7 @@ jobs:

     # Initializes the CodeQL tools for scanning.
     - name: Initialize CodeQL
-      uses: github/codeql-action/init@v3
+      uses: github/codeql-action/init@v4
       with:
         languages: ${{ matrix.language }}
         # If you wish to specify custom queries, you can do so here or in a config file.
@@ -72,4 +72,4 @@ jobs:
          BUILD_MONAI=1 ./runtests.sh --build

     - name: Perform CodeQL Analysis
-      uses: github/codeql-action/analyze@v3
+      uses: github/codeql-action/analyze@v4

.github/workflows/docker.yml

Lines changed: 1 addition & 1 deletion
@@ -37,7 +37,7 @@ jobs:
           python setup.py build
           cat build/lib/monai/_version.py
       - name: Upload version
-        uses: actions/upload-artifact@v4
+        uses: actions/upload-artifact@v5
         with:
           name: _version.py
           path: build/lib/monai/_version.py

.github/workflows/release.yml

Lines changed: 2 additions & 2 deletions
@@ -66,7 +66,7 @@ jobs:

       - if: matrix.python-version == '3.9' && startsWith(github.ref, 'refs/tags/')
         name: Upload artifacts
-        uses: actions/upload-artifact@v4
+        uses: actions/upload-artifact@v5
         with:
           name: dist
           path: dist/
@@ -109,7 +109,7 @@ jobs:
           python setup.py build
           cat build/lib/monai/_version.py
       - name: Upload version
-        uses: actions/upload-artifact@v4
+        uses: actions/upload-artifact@v5
         with:
           name: _version.py
           path: build/lib/monai/_version.py

README.md

Lines changed: 1 addition & 1 deletion
@@ -1,5 +1,5 @@
 <p align="center">
-  <img src="https://raw.githubusercontent.com/Project-MONAI/MONAI/dev/docs/images/MONAI-logo-color.png" width="50%" alt='project-monai'>
+  <img src="https://raw.githubusercontent.com/Project-MONAI/MONAI/dev/docs/images/MONAI-logo-color.png" width="50%" alt='project-monai'>
 </p>

 **M**edical **O**pen **N**etwork for **AI**

docs/source/transforms.rst

Lines changed: 17 additions & 0 deletions
@@ -37,6 +37,11 @@ Generic Interfaces
 .. autoclass:: MultiSampleTrait
     :members:

+`ReduceTrait`
+^^^^^^^^^^^^^^^^^^
+.. autoclass:: ReduceTrait
+    :members:
+
 `Randomizable`
 ^^^^^^^^^^^^^^
 .. autoclass:: Randomizable
@@ -1252,6 +1257,12 @@ Utility
     :members:
     :special-members: __call__

+`FlattenSequence`
+""""""""""""""""""""""""
+.. autoclass:: FlattenSequence
+    :members:
+    :special-members: __call__
+
 Dictionary Transforms
 ---------------------

@@ -2337,6 +2348,12 @@ Utility (Dict)
     :members:
     :special-members: __call__

+`FlattenSequenced`
+"""""""""""""""""""""""""
+.. autoclass:: FlattenSequenced
+    :members:
+    :special-members: __call__
+

 MetaTensor
 ^^^^^^^^^^

monai/apps/pathology/transforms/stain/array.py

Lines changed: 6 additions & 1 deletion
@@ -85,7 +85,12 @@ def _deconvolution_extract_stain(self, image: np.ndarray) -> np.ndarray:
         v_max = eigvecs[:, 1:3].dot(np.array([(np.cos(max_phi), np.sin(max_phi))], dtype=np.float32).T)

         # a heuristic to make the vector corresponding to hematoxylin first and the one corresponding to eosin second
-        if v_min[0] > v_max[0]:
+        # Hematoxylin: high blue, lower red (low R/B ratio)
+        # Eosin: high red, lower blue (high R/B ratio)
+        eps = np.finfo(np.float32).eps
+        v_min_rb_ratio = v_min[0] / (v_min[2] + eps)
+        v_max_rb_ratio = v_max[0] / (v_max[2] + eps)
+        if v_min_rb_ratio < v_max_rb_ratio:
             he = np.array((v_min[:, 0], v_max[:, 0]), dtype=np.float32).T
         else:
             he = np.array((v_max[:, 0], v_min[:, 0]), dtype=np.float32).T
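
For intuition, a minimal NumPy sketch of the new ordering heuristic in isolation; the stain vectors below are made-up illustrative values, not data from this commit.

import numpy as np

# Hypothetical 3x1 stain vectors in (R, G, B) optical-density space.
v_min = np.array([[0.25], [0.55], [0.80]], dtype=np.float32)  # blue-heavy, hematoxylin-like
v_max = np.array([[0.75], [0.60], [0.30]], dtype=np.float32)  # red-heavy, eosin-like

eps = np.finfo(np.float32).eps
v_min_rb_ratio = v_min[0] / (v_min[2] + eps)  # ~0.31: low red/blue ratio
v_max_rb_ratio = v_max[0] / (v_max[2] + eps)  # ~2.50: high red/blue ratio

# The vector with the lower R/B ratio is placed first (hematoxylin),
# the one with the higher R/B ratio second (eosin).
if v_min_rb_ratio < v_max_rb_ratio:
    he = np.array((v_min[:, 0], v_max[:, 0]), dtype=np.float32).T
else:
    he = np.array((v_max[:, 0], v_min[:, 0]), dtype=np.float32).T
print(he.shape)  # (3, 2): columns are [hematoxylin, eosin]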

monai/data/box_utils.py

Lines changed: 24 additions & 7 deletions
@@ -826,7 +826,10 @@ def box_iou(boxes1: NdarrayOrTensor, boxes2: NdarrayOrTensor) -> NdarrayOrTensor:
         boxes2: bounding boxes, Mx4 or Mx6 torch tensor or ndarray. The box mode is assumed to be ``StandardMode``

     Returns:
-        IoU, with size of (N,M) and same data type as ``boxes1``
+        An array/tensor matching the container type of ``boxes1`` (NumPy ndarray or Torch tensor), always
+        floating-point with size ``(N, M)``:
+        - if ``boxes1`` has a floating-point dtype, the same dtype is used.
+        - if ``boxes1`` has an integer dtype, the result is returned as ``torch.float32``.

     """

@@ -842,16 +845,18 @@ def box_iou(boxes1: NdarrayOrTensor, boxes2: NdarrayOrTensor) -> NdarrayOrTensor:

     inter, union = _box_inter_union(boxes1_t, boxes2_t, compute_dtype=COMPUTE_DTYPE)

-    # compute IoU and convert back to original box_dtype
+    # compute IoU and convert back to original box_dtype or torch.float32
     iou_t = inter / (union + torch.finfo(COMPUTE_DTYPE).eps)  # (N,M)
+    if not box_dtype.is_floating_point:
+        box_dtype = COMPUTE_DTYPE
     iou_t = iou_t.to(dtype=box_dtype)

     # check if NaN or Inf
     if torch.isnan(iou_t).any() or torch.isinf(iou_t).any():
         raise ValueError("Box IoU is NaN or Inf.")

     # convert tensor back to numpy if needed
-    iou, *_ = convert_to_dst_type(src=iou_t, dst=boxes1)
+    iou, *_ = convert_to_dst_type(src=iou_t, dst=boxes1, dtype=box_dtype)
     return iou


@@ -867,7 +872,10 @@ def box_giou(boxes1: NdarrayOrTensor, boxes2: NdarrayOrTensor) -> NdarrayOrTensor:
         boxes2: bounding boxes, Mx4 or Mx6 torch tensor or ndarray. The box mode is assumed to be ``StandardMode``

     Returns:
-        GIoU, with size of (N,M) and same data type as ``boxes1``
+        An array/tensor matching the container type of ``boxes1`` (NumPy ndarray or Torch tensor), always
+        floating-point with size ``(N, M)``:
+        - if ``boxes1`` has a floating-point dtype, the same dtype is used.
+        - if ``boxes1`` has an integer dtype, the result is returned as ``torch.float32``.

     Reference:
         https://giou.stanford.edu/GIoU.pdf
@@ -904,12 +912,15 @@ def box_giou(boxes1: NdarrayOrTensor, boxes2: NdarrayOrTensor) -> NdarrayOrTensor:

     # GIoU
     giou_t = iou - (enclosure - union) / (enclosure + torch.finfo(COMPUTE_DTYPE).eps)
+    if not box_dtype.is_floating_point:
+        box_dtype = COMPUTE_DTYPE
     giou_t = giou_t.to(dtype=box_dtype)
+
     if torch.isnan(giou_t).any() or torch.isinf(giou_t).any():
         raise ValueError("Box GIoU is NaN or Inf.")

     # convert tensor back to numpy if needed
-    giou, *_ = convert_to_dst_type(src=giou_t, dst=boxes1)
+    giou, *_ = convert_to_dst_type(src=giou_t, dst=boxes1, dtype=box_dtype)
     return giou


@@ -925,7 +936,10 @@ def box_pair_giou(boxes1: NdarrayOrTensor, boxes2: NdarrayOrTensor) -> NdarrayOrTensor:
         boxes2: bounding boxes, same shape with boxes1. The box mode is assumed to be ``StandardMode``

     Returns:
-        paired GIoU, with size of (N,) and same data type as ``boxes1``
+        An array/tensor matching the container type of ``boxes1`` (NumPy ndarray or Torch tensor), always
+        floating-point with size ``(N, )``:
+        - if ``boxes1`` has a floating-point dtype, the same dtype is used.
+        - if ``boxes1`` has an integer dtype, the result is returned as ``torch.float32``.

     Reference:
         https://giou.stanford.edu/GIoU.pdf
@@ -982,12 +996,15 @@ def box_pair_giou(boxes1: NdarrayOrTensor, boxes2: NdarrayOrTensor) -> NdarrayOrTensor:
     enclosure = torch.prod(wh, dim=-1, keepdim=False)  # (N,)

     giou_t: torch.Tensor = iou - (enclosure - union) / (enclosure + torch.finfo(COMPUTE_DTYPE).eps)  # type: ignore
+    if not box_dtype.is_floating_point:
+        box_dtype = COMPUTE_DTYPE
     giou_t = giou_t.to(dtype=box_dtype)  # (N,spatial_dims)
+
     if torch.isnan(giou_t).any() or torch.isinf(giou_t).any():
         raise ValueError("Box GIoU is NaN or Inf.")

     # convert tensor back to numpy if needed
-    giou, *_ = convert_to_dst_type(src=giou_t, dst=boxes1)
+    giou, *_ = convert_to_dst_type(src=giou_t, dst=boxes1, dtype=box_dtype)
     return giou
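A minimal sketch of the new dtype behaviour for integer boxes, assuming `StandardMode` corner coordinates; the box values are illustrative and the printed IoU assumes MONAI's corner-exclusive box areas.

import torch
from monai.data.box_utils import box_iou

# Integer boxes in StandardMode (xyxy). The IoU is now returned as float32
# instead of being cast back to the integer box dtype (which would truncate it to 0).
boxes1 = torch.tensor([[0, 0, 10, 10]], dtype=torch.int64)
boxes2 = torch.tensor([[5, 5, 15, 15]], dtype=torch.int64)

iou = box_iou(boxes1, boxes2)
print(iou.dtype)  # torch.float32
print(iou)        # overlap 25 / union 175, roughly 0.14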

monai/data/dataset.py

Lines changed: 22 additions & 3 deletions
@@ -230,6 +230,8 @@ def __init__(
         pickle_protocol: int = DEFAULT_PROTOCOL,
         hash_transform: Callable[..., bytes] | None = None,
         reset_ops_id: bool = True,
+        track_meta: bool = False,
+        weights_only: bool = True,
     ) -> None:
         """
         Args:
@@ -264,7 +266,17 @@ def __init__(
                 When this is enabled, the traced transform instance IDs will be removed from the cached MetaTensors.
                 This is useful for skipping the transform instance checks when inverting applied operations
                 using the cached content and with re-created transform instances.
-
+            track_meta: whether to track the meta information, if `True`, will convert to `MetaTensor`.
+                default to `False`. Cannot be used with `weights_only=True`.
+            weights_only: keyword argument passed to `torch.load` when reading cached files.
+                default to `True`. When set to `True`, `torch.load` restricts loading to tensors and
+                other safe objects. Setting this to `False` is required for loading `MetaTensor`
+                objects saved with `track_meta=True`, however this creates the possibility of remote
+                code execution through `torch.load` so be aware of the security implications of doing so.
+
+        Raises:
+            ValueError: When both `track_meta=True` and `weights_only=True`, since this combination
+                prevents cached MetaTensors from being reloaded and causes perpetual cache regeneration.
         """
         super().__init__(data=data, transform=transform)
         self.cache_dir = Path(cache_dir) if cache_dir is not None else None
@@ -280,6 +292,13 @@ def __init__(
         if hash_transform is not None:
             self.set_transform_hash(hash_transform)
         self.reset_ops_id = reset_ops_id
+        if track_meta and weights_only:
+            raise ValueError(
+                "Invalid argument combination: `track_meta=True` cannot be used with `weights_only=True`. "
+                "To cache and reload MetaTensors, set `track_meta=True` and `weights_only=False`."
+            )
+        self.track_meta = track_meta
+        self.weights_only = weights_only

     def set_transform_hash(self, hash_xform_func: Callable[..., bytes]):
         """Get hashable transforms, and then hash them. Hashable transforms
@@ -377,7 +396,7 @@ def _cachecheck(self, item_transformed):

         if hashfile is not None and hashfile.is_file():  # cache hit
             try:
-                return torch.load(hashfile, weights_only=True)
+                return torch.load(hashfile, weights_only=self.weights_only)
             except PermissionError as e:
                 if sys.platform != "win32":
                     raise e
@@ -398,7 +417,7 @@ def _cachecheck(self, item_transformed):
         with tempfile.TemporaryDirectory() as tmpdirname:
             temp_hash_file = Path(tmpdirname) / hashfile.name
             torch.save(
-                obj=convert_to_tensor(_item_transformed, convert_numeric=False),
+                obj=convert_to_tensor(_item_transformed, convert_numeric=False, track_meta=self.track_meta),
                 f=temp_hash_file,
                 pickle_module=look_up_option(self.pickle_module, SUPPORTED_PICKLE_MOD),
                 pickle_protocol=self.pickle_protocol,
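
A usage sketch of the new caching options, assuming the class shown here is `PersistentDataset` (consistent with the `cache_dir`, `hash_transform` and `pickle_protocol` arguments above); the file paths and transform pipeline are placeholders.

from monai.data import PersistentDataset
from monai.transforms import Compose, EnsureChannelFirstd, LoadImaged

data = [{"image": "img0.nii.gz"}, {"image": "img1.nii.gz"}]  # placeholder paths
xform = Compose([LoadImaged(keys="image"), EnsureChannelFirstd(keys="image")])

# Default behaviour: cache plain tensors and reload them with torch.load(weights_only=True).
ds_plain = PersistentDataset(data=data, transform=xform, cache_dir="./cache")

# To cache and reload MetaTensors, both new flags must change together.
# weights_only=False relaxes torch.load's safety restrictions, so only use it
# with cache directories you trust.
ds_meta = PersistentDataset(
    data=data,
    transform=xform,
    cache_dir="./cache_meta",
    track_meta=True,
    weights_only=False,
)

# track_meta=True combined with the default weights_only=True raises ValueError.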
monai/networks/blocks/activation_checkpointing.py

Lines changed: 41 additions & 0 deletions

@@ -0,0 +1,41 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#     http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+from typing import cast
+
+import torch
+import torch.nn as nn
+from torch.utils.checkpoint import checkpoint
+
+
+class ActivationCheckpointWrapper(nn.Module):
+    """Wrapper applying activation checkpointing to a module during training.
+
+    Args:
+        module: The module to wrap with activation checkpointing.
+    """
+
+    def __init__(self, module: nn.Module) -> None:
+        super().__init__()
+        self.module = module
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """Forward pass with optional activation checkpointing.
+
+        Args:
+            x: Input tensor.
+
+        Returns:
+            Output tensor from the wrapped module.
+        """
+        return cast(torch.Tensor, checkpoint(self.module, x, use_reentrant=False))
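
A short usage sketch of the new wrapper; the wrapped convolution and the input shape are arbitrary examples.

import torch
import torch.nn as nn
from monai.networks.blocks.activation_checkpointing import ActivationCheckpointWrapper

# Wrap any sub-module; its intermediate activations are recomputed during backward.
block = ActivationCheckpointWrapper(nn.Conv2d(1, 8, kernel_size=3, padding=1))

x = torch.randn(2, 1, 32, 32, requires_grad=True)
y = block(x)        # forward runs the conv under torch.utils.checkpoint
y.sum().backward()  # gradients flow as usual, with lower peak activation memory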

monai/networks/nets/unet.py

Lines changed: 27 additions & 1 deletion
@@ -17,11 +17,12 @@
 import torch
 import torch.nn as nn

+from monai.networks.blocks.activation_checkpointing import ActivationCheckpointWrapper
 from monai.networks.blocks.convolutions import Convolution, ResidualUnit
 from monai.networks.layers.factories import Act, Norm
 from monai.networks.layers.simplelayers import SkipConnection

-__all__ = ["UNet", "Unet"]
+__all__ = ["UNet", "Unet", "CheckpointUNet"]


 class UNet(nn.Module):
@@ -298,4 +299,29 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         return x


+class CheckpointUNet(UNet):
+    """UNet variant that wraps internal connection blocks with activation checkpointing.
+
+    See `UNet` for constructor arguments. During training with gradients enabled,
+    intermediate activations inside encoder-decoder connections are recomputed in
+    the backward pass to reduce peak memory usage at the cost of extra compute.
+    """
+
+    def _get_connection_block(self, down_path: nn.Module, up_path: nn.Module, subblock: nn.Module) -> nn.Module:
+        """Returns connection block with activation checkpointing applied to all components.
+
+        Args:
+            down_path: encoding half of the layer (will be wrapped with checkpointing).
+            up_path: decoding half of the layer (will be wrapped with checkpointing).
+            subblock: block defining the next layer (will be wrapped with checkpointing).
+
+        Returns:
+            Connection block with all components wrapped for activation checkpointing.
+        """
+        subblock = ActivationCheckpointWrapper(subblock)
+        down_path = ActivationCheckpointWrapper(down_path)
+        up_path = ActivationCheckpointWrapper(up_path)
+        return super()._get_connection_block(down_path, up_path, subblock)
+
+
 Unet = UNet
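
A brief sketch of swapping in the checkpointed variant; the hyperparameters below are typical `UNet` arguments chosen for illustration, not values taken from this commit.

import torch
from monai.networks.nets.unet import CheckpointUNet

net = CheckpointUNet(
    spatial_dims=3,
    in_channels=1,
    out_channels=2,
    channels=(16, 32, 64, 128),
    strides=(2, 2, 2),
    num_res_units=2,
)

net.train()
x = torch.randn(1, 1, 64, 64, 64)
out = net(x)  # same API and output shape as UNet; connection blocks recompute
              # their activations during the backward pass to save memory
print(out.shape)  # torch.Size([1, 2, 64, 64, 64])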
