Skip to content

Commit 97dc6c6

Browse files
authored
Merge branch 'dev' into fix-compose
2 parents 44066af + df1ba5d commit 97dc6c6

File tree

4 files changed

+54
-7
lines changed

4 files changed

+54
-7
lines changed

monai/bundle/scripts.py

Lines changed: 20 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -174,7 +174,7 @@ def _get_git_release_url(repo_owner: str, repo_name: str, tag_name: str, filenam
174174

175175

176176
def _get_ngc_bundle_url(model_name: str, version: str) -> str:
177-
return f"{NGC_BASE_URL}/{model_name.lower()}/versions/{version}/zip"
177+
return f"{NGC_BASE_URL}/{model_name.lower()}/versions/{version}/files"
178178

179179

180180
def _get_ngc_private_base_url(repo: str) -> str:
@@ -218,6 +218,21 @@ def _remove_ngc_prefix(name: str, prefix: str = "monai_") -> str:
218218
return name
219219

220220

221+
def _get_all_download_files(request_url: str, headers: dict | None = None) -> list[str]:
222+
if not has_requests:
223+
raise ValueError("requests package is required, please install it.")
224+
headers = {} if headers is None else headers
225+
response = requests_get(request_url, headers=headers)
226+
response.raise_for_status()
227+
model_info = json.loads(response.text)
228+
229+
if not isinstance(model_info, dict) or "modelFiles" not in model_info:
230+
raise ValueError("The data is not a dictionary or it does not have the key 'modelFiles'.")
231+
232+
model_files = model_info["modelFiles"]
233+
return [f["path"] for f in model_files]
234+
235+
221236
def _download_from_ngc(
222237
download_path: Path,
223238
filename: str,
@@ -229,12 +244,12 @@ def _download_from_ngc(
229244
# ensure prefix is contained
230245
filename = _add_ngc_prefix(filename, prefix=prefix)
231246
url = _get_ngc_bundle_url(model_name=filename, version=version)
232-
filepath = download_path / f"{filename}_v{version}.zip"
233247
if remove_prefix:
234248
filename = _remove_ngc_prefix(filename, prefix=remove_prefix)
235-
extract_path = download_path / f"{filename}"
236-
download_url(url=url, filepath=filepath, hash_val=None, progress=progress)
237-
extractall(filepath=filepath, output_dir=extract_path, has_base=True)
249+
filepath = download_path / filename
250+
filepath.mkdir(parents=True, exist_ok=True)
251+
for file in _get_all_download_files(url):
252+
download_url(url=f"{url}/{file}", filepath=f"{filepath}/{file}", hash_val=None, progress=progress)
238253

239254

240255
def _download_from_ngc_private(

monai/transforms/intensity/array.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -821,6 +821,7 @@ class NormalizeIntensity(Transform):
821821
mean and std on each channel separately.
822822
When `channel_wise` is True, the first dimension of `subtrahend` and `divisor` should
823823
be the number of image channels if they are not None.
824+
If the input is not of floating point type, it will be converted to float32.
824825
825826
Args:
826827
subtrahend: the amount to subtract by (usually the mean).
@@ -907,6 +908,9 @@ def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor:
907908
if self.divisor is not None and len(self.divisor) != len(img):
908909
raise ValueError(f"img has {len(img)} channels, but divisor has {len(self.divisor)} components.")
909910

911+
if not img.dtype.is_floating_point:
912+
img, *_ = convert_data_type(img, dtype=torch.float32)
913+
910914
for i, d in enumerate(img):
911915
img[i] = self._normalize( # type: ignore
912916
d,

tests/test_normalize_intensity.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,27 @@ def test_channel_wise(self, im_type):
108108
normalized = normalizer(input_data)
109109
assert_allclose(normalized, im_type(expected), type_test="tensor")
110110

111+
@parameterized.expand([[p] for p in TEST_NDARRAYS])
112+
def test_channel_wise_int(self, im_type):
113+
normalizer = NormalizeIntensity(nonzero=True, channel_wise=True)
114+
input_data = im_type(torch.arange(1, 25).reshape(2, 3, 4))
115+
expected = np.array(
116+
[
117+
[
118+
[-1.593255, -1.3035723, -1.0138896, -0.7242068],
119+
[-0.4345241, -0.1448414, 0.1448414, 0.4345241],
120+
[0.7242068, 1.0138896, 1.3035723, 1.593255],
121+
],
122+
[
123+
[-1.593255, -1.3035723, -1.0138896, -0.7242068],
124+
[-0.4345241, -0.1448414, 0.1448414, 0.4345241],
125+
[0.7242068, 1.0138896, 1.3035723, 1.593255],
126+
],
127+
]
128+
)
129+
normalized = normalizer(input_data)
130+
assert_allclose(normalized, im_type(expected), type_test="tensor", rtol=1e-7, atol=1e-7) # tolerance
131+
111132
@parameterized.expand([[p] for p in TEST_NDARRAYS])
112133
def test_value_errors(self, im_type):
113134
input_data = im_type(np.array([[0.0, 3.0, 0.0, 4.0], [0.0, 4.0, 0.0, 5.0]]))

tests/test_zarr_avg_merger.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,11 +19,18 @@
1919
from torch.nn.functional import pad
2020

2121
from monai.inferers import ZarrAvgMerger
22-
from monai.utils import optional_import
22+
from monai.utils import get_package_version, optional_import, version_geq
2323
from tests.utils import assert_allclose
2424

2525
np.seterr(divide="ignore", invalid="ignore")
2626
zarr, has_zarr = optional_import("zarr")
27+
if has_zarr:
28+
if version_geq(get_package_version("zarr"), "3.0.0"):
29+
directory_store = zarr.storage.LocalStore("test.zarr")
30+
else:
31+
directory_store = zarr.storage.DirectoryStore("test.zarr")
32+
else:
33+
directory_store = None
2734
numcodecs, has_numcodecs = optional_import("numcodecs")
2835

2936
TENSOR_4x4 = torch.randint(low=0, high=255, size=(2, 3, 4, 4), dtype=torch.float32)
@@ -154,7 +161,7 @@
154161

155162
# explicit directory store
156163
TEST_CASE_10_DIRECTORY_STORE = [
157-
dict(merged_shape=TENSOR_4x4.shape, store=zarr.storage.DirectoryStore("test.zarr")),
164+
dict(merged_shape=TENSOR_4x4.shape, store=directory_store),
158165
[
159166
(TENSOR_4x4[..., :2, :2], (0, 0)),
160167
(TENSOR_4x4[..., :2, 2:], (0, 2)),

0 commit comments

Comments
 (0)