Commit 80b8757

Merge branch 'dev' into bump-torch-minimum
Signed-off-by: YunLiu <[email protected]>
2 parents cfac884 + 8dcb9dc

File tree: 13 files changed, +339 −98 lines

.pre-commit-config.yaml

Lines changed: 1 addition & 1 deletion
@@ -66,7 +66,7 @@ repos:
           )$
 
   - repo: https://github.com/hadialqattan/pycln
-    rev: v2.4.0
+    rev: v2.5.0
     hooks:
       - id: pycln
         args: [--config=pyproject.toml]

monai/bundle/scripts.py

Lines changed: 20 additions & 5 deletions
@@ -174,7 +174,7 @@ def _get_git_release_url(repo_owner: str, repo_name: str, tag_name: str, filenam
 
 
 def _get_ngc_bundle_url(model_name: str, version: str) -> str:
-    return f"{NGC_BASE_URL}/{model_name.lower()}/versions/{version}/zip"
+    return f"{NGC_BASE_URL}/{model_name.lower()}/versions/{version}/files"
 
 
 def _get_ngc_private_base_url(repo: str) -> str:
@@ -218,6 +218,21 @@ def _remove_ngc_prefix(name: str, prefix: str = "monai_") -> str:
     return name
 
 
+def _get_all_download_files(request_url: str, headers: dict | None = None) -> list[str]:
+    if not has_requests:
+        raise ValueError("requests package is required, please install it.")
+    headers = {} if headers is None else headers
+    response = requests_get(request_url, headers=headers)
+    response.raise_for_status()
+    model_info = json.loads(response.text)
+
+    if not isinstance(model_info, dict) or "modelFiles" not in model_info:
+        raise ValueError("The data is not a dictionary or it does not have the key 'modelFiles'.")
+
+    model_files = model_info["modelFiles"]
+    return [f["path"] for f in model_files]
+
+
 def _download_from_ngc(
     download_path: Path,
     filename: str,
@@ -229,12 +244,12 @@ def _download_from_ngc(
     # ensure prefix is contained
     filename = _add_ngc_prefix(filename, prefix=prefix)
     url = _get_ngc_bundle_url(model_name=filename, version=version)
-    filepath = download_path / f"{filename}_v{version}.zip"
     if remove_prefix:
         filename = _remove_ngc_prefix(filename, prefix=remove_prefix)
-    extract_path = download_path / f"{filename}"
-    download_url(url=url, filepath=filepath, hash_val=None, progress=progress)
-    extractall(filepath=filepath, output_dir=extract_path, has_base=True)
+    filepath = download_path / filename
+    filepath.mkdir(parents=True, exist_ok=True)
+    for file in _get_all_download_files(url):
+        download_url(url=f"{url}/{file}", filepath=f"{filepath}/{file}", hash_val=None, progress=progress)
 
 
 def _download_from_ngc_private(
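Context for the change above: the NGC bundle URL now points at the /files manifest endpoint instead of a single zip archive, and _download_from_ngc fetches each listed file into a per-bundle directory rather than downloading and extracting a zip. The sketch below shows only the manifest-parsing step that _get_all_download_files performs, using a made-up manifest (the real shape comes from the NGC API, so treat the example payload as an assumption):

import json

# hypothetical manifest: a dict with a "modelFiles" list of {"path": ...} entries
manifest_text = '{"modelFiles": [{"path": "configs/metadata.json"}, {"path": "models/model.pt"}]}'
model_info = json.loads(manifest_text)

if not isinstance(model_info, dict) or "modelFiles" not in model_info:
    raise ValueError("The data is not a dictionary or it does not have the key 'modelFiles'.")

# one relative path per downloadable file; each is appended to the bundle URL before download
paths = [f["path"] for f in model_info["modelFiles"]]
print(paths)  # ['configs/metadata.json', 'models/model.pt']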

monai/data/image_reader.py

Lines changed: 170 additions & 49 deletions
Large diffs are not rendered by default.

monai/data/meta_tensor.py

Lines changed: 5 additions & 0 deletions
@@ -607,3 +607,8 @@ def print_verbose(self) -> None:
         print(self)
         if self.meta is not None:
             print(self.meta.__repr__())
+
+
+# needed in later versions of Pytorch to indicate the class is safe for serialisation
+if hasattr(torch.serialization, "add_safe_globals"):
+    torch.serialization.add_safe_globals([MetaTensor])
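Background for this addition: recent PyTorch releases default torch.load to weights_only=True, which only unpickles classes that have been registered as safe globals, so MetaTensor is registered at import time when the API exists. A minimal sketch of the same guarded registration for one's own class (the class name is purely illustrative):

import torch

class TaggedTensor(torch.Tensor):
    """Hypothetical tensor subclass, standing in for MetaTensor."""

# add_safe_globals only exists in newer PyTorch releases, hence the hasattr guard
if hasattr(torch.serialization, "add_safe_globals"):
    torch.serialization.add_safe_globals([TaggedTensor])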

monai/inferers/merger.py

Lines changed: 19 additions & 4 deletions
@@ -15,12 +15,13 @@
 from abc import ABC, abstractmethod
 from collections.abc import Sequence
 from contextlib import nullcontext
+from tempfile import TemporaryDirectory
 from typing import TYPE_CHECKING, Any
 
 import numpy as np
 import torch
 
-from monai.utils import ensure_tuple_size, optional_import, require_pkg
+from monai.utils import ensure_tuple_size, get_package_version, optional_import, require_pkg, version_geq
 
 if TYPE_CHECKING:
     import zarr
@@ -233,7 +234,7 @@ def __init__(
         store: zarr.storage.Store | str = "merged.zarr",
         value_store: zarr.storage.Store | str | None = None,
         count_store: zarr.storage.Store | str | None = None,
-        compressor: str = "default",
+        compressor: str | None = None,
         value_compressor: str | None = None,
         count_compressor: str | None = None,
         chunks: Sequence[int] | bool = True,
@@ -246,8 +247,22 @@
         self.value_dtype = value_dtype
         self.count_dtype = count_dtype
         self.store = store
-        self.value_store = zarr.storage.TempStore() if value_store is None else value_store
-        self.count_store = zarr.storage.TempStore() if count_store is None else count_store
+        self.tmpdir: TemporaryDirectory | None
+        if version_geq(get_package_version("zarr"), "3.0.0"):
+            if value_store is None:
+                self.tmpdir = TemporaryDirectory()
+                self.value_store = zarr.storage.LocalStore(self.tmpdir.name)
+            else:
+                self.value_store = value_store
+            if count_store is None:
+                self.tmpdir = TemporaryDirectory()
+                self.count_store = zarr.storage.LocalStore(self.tmpdir.name)
+            else:
+                self.count_store = count_store
+        else:
+            self.tmpdir = None
+            self.value_store = zarr.storage.TempStore() if value_store is None else value_store
+            self.count_store = zarr.storage.TempStore() if count_store is None else count_store
         self.chunks = chunks
         self.compressor = compressor
         self.value_compressor = value_compressor
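The new branch handles the zarr 3.x API: zarr.storage.TempStore is only used below 3.0.0, while 3.0.0 and later get a TemporaryDirectory-backed zarr.storage.LocalStore. A standalone sketch of the same version switch, using packaging in place of MONAI's get_package_version/version_geq helpers (an assumption made for brevity):

from tempfile import TemporaryDirectory

import zarr
from packaging.version import Version

if Version(zarr.__version__) >= Version("3.0.0"):
    # zarr 3.x: back a LocalStore with a temporary directory the caller owns
    tmpdir = TemporaryDirectory()
    store = zarr.storage.LocalStore(tmpdir.name)
else:
    # zarr 2.x: TempStore manages its own temporary directory
    store = zarr.storage.TempStore()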

monai/transforms/intensity/array.py

Lines changed: 4 additions & 0 deletions
@@ -821,6 +821,7 @@ class NormalizeIntensity(Transform):
     mean and std on each channel separately.
     When `channel_wise` is True, the first dimension of `subtrahend` and `divisor` should
     be the number of image channels if they are not None.
+    If the input is not of floating point type, it will be converted to float32
 
     Args:
         subtrahend: the amount to subtract by (usually the mean).
@@ -907,6 +908,9 @@ def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor:
             if self.divisor is not None and len(self.divisor) != len(img):
                 raise ValueError(f"img has {len(img)} channels, but divisor has {len(self.divisor)} components.")
 
+            if not img.dtype.is_floating_point:
+                img, *_ = convert_data_type(img, dtype=torch.float32)
+
             for i, d in enumerate(img):
                 img[i] = self._normalize(  # type: ignore
                     d,
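With this change, non-floating-point images passed to NormalizeIntensity in channel-wise mode are promoted to float32 before the per-channel results are written back into the tensor. A short usage sketch (the printed dtype is the expected outcome, not verified here):

import torch
from monai.transforms import NormalizeIntensity

# an int32 image with 3 channels and 2x2 spatial size
img = torch.arange(12, dtype=torch.int32).reshape(3, 2, 2)

# channel-wise normalisation now converts the input to float32 internally,
# so the in-place per-channel writes are no longer constrained to integer values
out = NormalizeIntensity(channel_wise=True)(img)
print(out.dtype)  # expected: torch.float32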

monai/utils/jupyter_utils.py

Lines changed: 1 addition & 1 deletion
@@ -234,7 +234,7 @@ def plot_engine_status(
 
 
 def _get_loss_from_output(
-    output: list[torch.Tensor | dict[str, torch.Tensor]] | dict[str, torch.Tensor] | torch.Tensor
+    output: list[torch.Tensor | dict[str, torch.Tensor]] | dict[str, torch.Tensor] | torch.Tensor,
 ) -> torch.Tensor:
     """Returns a single value from the network output, which is a dict or tensor."""

monai/visualize/img2tensorboard.py

Lines changed: 2 additions & 2 deletions
@@ -65,11 +65,11 @@ def _image3_animated_gif(
     img_str = b""
     for b_data in PIL.GifImagePlugin.getheader(ims[0])[0]:
         img_str += b_data
-    img_str += b"\x21\xFF\x0B\x4E\x45\x54\x53\x43\x41\x50" b"\x45\x32\x2E\x30\x03\x01\x00\x00\x00"
+    img_str += b"\x21\xff\x0b\x4e\x45\x54\x53\x43\x41\x50" b"\x45\x32\x2e\x30\x03\x01\x00\x00\x00"
     for i in ims:
         for b_data in PIL.GifImagePlugin.getdata(i):
             img_str += b_data
-    img_str += b"\x3B"
+    img_str += b"\x3b"
 
     summary = SummaryX if has_tensorboardx and isinstance(writer, SummaryWriterX) else Summary
     summary_image_str = summary.Image(height=10, width=10, colorspace=1, encoded_image_string=img_str)

requirements-dev.txt

Lines changed: 1 addition & 1 deletion
@@ -18,7 +18,7 @@ pep8-naming
 pycodestyle
 pyflakes
 black>=22.12
-isort>=5.1
+isort>=5.1, <6.0
 ruff
 pytype>=2020.6.1; platform_system != "Windows"
 types-setuptools

requirements.txt

Lines changed: 1 addition & 1 deletion
@@ -1,2 +1,2 @@
-torch>=1.13.1
+torch>=1.13.1,<2.6
 numpy>=1.24,<2.0
