
Commit ca6796b

Merge branch 'Project-MONAI:dev' into 4980-get-wsi-at-mpp
2 parents 349c011 + 9eb0a8c

File tree

11 files changed (+210, -35 lines)


monai/apps/utils.py

Lines changed: 38 additions & 1 deletion

@@ -15,6 +15,7 @@
 import json
 import logging
 import os
+import re
 import shutil
 import sys
 import tarfile
@@ -30,7 +31,9 @@
 from monai.config.type_definitions import PathLike
 from monai.utils import look_up_option, min_version, optional_import
 
+requests, has_requests = optional_import("requests")
 gdown, has_gdown = optional_import("gdown", "4.7.3")
+BeautifulSoup, has_bs4 = optional_import("bs4", name="BeautifulSoup")
 
 if TYPE_CHECKING:
     from tqdm import tqdm
@@ -298,6 +301,29 @@ def extractall(
     )
 
 
+def get_filename_from_url(data_url: str) -> str:
+    """
+    Get the filename from the URL link.
+    """
+    try:
+        response = requests.head(data_url, allow_redirects=True)
+        content_disposition = response.headers.get("Content-Disposition")
+        if content_disposition:
+            filename = re.findall('filename="?([^";]+)"?', content_disposition)
+            if filename:
+                return str(filename[0])
+        if "drive.google.com" in data_url:
+            response = requests.get(data_url)
+            if "text/html" in response.headers.get("Content-Type", ""):
+                soup = BeautifulSoup(response.text, "html.parser")
+                filename_div = soup.find("span", {"class": "uc-name-size"})
+                if filename_div:
+                    return str(filename_div.find("a").text)
+        return _basename(data_url)
+    except Exception as e:
+        raise Exception(f"Error processing URL: {e}") from e
+
+
 def download_and_extract(
     url: str,
     filepath: PathLike = "",
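The filename extraction above leans on a single regular expression for the Content-Disposition header; the BeautifulSoup branch only covers Google Drive's HTML download page. A minimal sketch of how that regex behaves (illustrative header values, not from the commit):

    import re

    PATTERN = 'filename="?([^";]+)"?'

    # quoted filename, as most servers send it
    print(re.findall(PATTERN, 'attachment; filename="brain_mr.tar.gz"'))  # ['brain_mr.tar.gz']

    # the optional quotes in the pattern also cover unquoted filenames
    print(re.findall(PATTERN, "attachment; filename=data.zip"))  # ['data.zip']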
@@ -327,7 +353,18 @@ def download_and_extract(
         be False.
     progress: whether to display progress bar.
     """
+    url_filename_ext = "".join(Path(get_filename_from_url(url)).suffixes)
+    filepath_ext = "".join(Path(_basename(filepath)).suffixes)
+    if filepath not in ["", "."]:
+        if filepath_ext == "":
+            new_filepath = Path(filepath).with_suffix(url_filename_ext)
+            logger.warning(
+                f"filepath={filepath}, which missing file extension. Auto-appending extension to: {new_filepath}"
+            )
+            filepath = new_filepath
+        if filepath_ext and filepath_ext != url_filename_ext:
+            raise ValueError(f"File extension mismatch: expected extension {url_filename_ext}, but get {filepath_ext}")
     with tempfile.TemporaryDirectory() as tmp_dir:
-        filename = filepath or Path(tmp_dir, _basename(url)).resolve()
+        filename = filepath or Path(tmp_dir, get_filename_from_url(url)).resolve()
         download_url(url=url, filepath=filename, hash_val=hash_val, hash_type=hash_type, progress=progress)
         extractall(filepath=filename, output_dir=output_dir, file_type=file_type, has_base=has_base)
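Taken together, `download_and_extract` now completes a missing output extension from the server-reported filename and rejects a mismatched one before downloading. A hedged usage sketch (the URL and paths are hypothetical):

    from monai.apps import download_and_extract

    # hypothetical endpoint whose response reports filename="dataset.tar.gz"
    url = "https://example.com/download?id=123"

    # no extension on filepath: a warning is logged and ".tar.gz" is appended
    download_and_extract(url, filepath="./dataset", output_dir="./data")

    # wrong extension on filepath: raises ValueError("File extension mismatch: ...")
    # download_and_extract(url, filepath="./dataset.zip", output_dir="./data")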

monai/data/image_reader.py

Lines changed: 77 additions & 9 deletions

@@ -12,8 +12,11 @@
 from __future__ import annotations
 
 import glob
+import gzip
+import io
 import os
 import re
+import tempfile
 import warnings
 from abc import ABC, abstractmethod
 from collections.abc import Callable, Iterable, Iterator, Sequence
@@ -51,6 +54,9 @@
 pydicom, has_pydicom = optional_import("pydicom")
 nrrd, has_nrrd = optional_import("nrrd", allow_namespace_pkg=True)
 
+cp, has_cp = optional_import("cupy")
+kvikio, has_kvikio = optional_import("kvikio")
+
 __all__ = ["ImageReader", "ITKReader", "NibabelReader", "NumpyReader", "PILReader", "PydicomReader", "NrrdReader"]
 
 
@@ -137,14 +143,18 @@ def _copy_compatible_dict(from_dict: dict, to_dict: dict):
     )
 
 
-def _stack_images(image_list: list, meta_dict: dict):
+def _stack_images(image_list: list, meta_dict: dict, to_cupy: bool = False):
     if len(image_list) <= 1:
         return image_list[0]
     if not is_no_channel(meta_dict.get(MetaKeys.ORIGINAL_CHANNEL_DIM, None)):
         channel_dim = int(meta_dict[MetaKeys.ORIGINAL_CHANNEL_DIM])
+        if to_cupy and has_cp:
+            return cp.concatenate(image_list, axis=channel_dim)
         return np.concatenate(image_list, axis=channel_dim)
     # stack at a new first dim as the channel dim, if `'original_channel_dim'` is unspecified
     meta_dict[MetaKeys.ORIGINAL_CHANNEL_DIM] = 0
+    if to_cupy and has_cp:
+        return cp.stack(image_list, axis=0)
     return np.stack(image_list, axis=0)
 
@@ -864,12 +874,18 @@ class NibabelReader(ImageReader):
     Load NIfTI format images based on Nibabel library.
 
     Args:
-        as_closest_canonical: if True, load the image as closest to canonical axis format.
-        squeeze_non_spatial_dims: if True, non-spatial singletons will be squeezed, e.g. (256,256,1,3) -> (256,256,3)
         channel_dim: the channel dimension of the input image, default is None.
             this is used to set original_channel_dim in the metadata, EnsureChannelFirstD reads this field.
             if None, `original_channel_dim` will be either `no_channel` or `-1`.
             most Nifti files are usually "channel last", no need to specify this argument for them.
+        as_closest_canonical: if True, load the image as closest to canonical axis format.
+        squeeze_non_spatial_dims: if True, non-spatial singletons will be squeezed, e.g. (256,256,1,3) -> (256,256,3)
+        to_gpu: If True, load the image into GPU memory using CuPy and Kvikio. This can accelerate data loading.
+            Default is False. CuPy and Kvikio are required for this option.
+            Note: For compressed NIfTI files, some operations may still be performed on CPU memory,
+            and the acceleration may not be significant. In some cases, it may be slower than loading on CPU.
+            In practical use, it's recommended to add a warm up call before the actual loading.
+            A related tutorial will be prepared in the future, and the document will be updated accordingly.
         kwargs: additional args for `nibabel.load` API. more details about available args:
             https://github.com/nipy/nibabel/blob/master/nibabel/loadsave.py
 
@@ -880,14 +896,42 @@ def __init__(
         channel_dim: str | int | None = None,
         as_closest_canonical: bool = False,
         squeeze_non_spatial_dims: bool = False,
+        to_gpu: bool = False,
         **kwargs,
     ):
         super().__init__()
         self.channel_dim = float("nan") if channel_dim == "no_channel" else channel_dim
         self.as_closest_canonical = as_closest_canonical
         self.squeeze_non_spatial_dims = squeeze_non_spatial_dims
+        if to_gpu and (not has_cp or not has_kvikio):
+            warnings.warn(
+                "NibabelReader: CuPy and/or Kvikio not installed for GPU loading, falling back to CPU loading."
+            )
+            to_gpu = False
+
+        if to_gpu:
+            self.warmup_kvikio()
+
+        self.to_gpu = to_gpu
         self.kwargs = kwargs
 
+    def warmup_kvikio(self):
+        """
+        Warm up the Kvikio library to initialize the internal buffers, cuFile, GDS, etc.
+        This can accelerate the data loading process when `to_gpu` is set to True.
+        """
+        if has_cp and has_kvikio:
+            a = cp.arange(100)
+            with tempfile.NamedTemporaryFile() as tmp_file:
+                tmp_file_name = tmp_file.name
+                f = kvikio.CuFile(tmp_file_name, "w")
+                f.write(a)
+                f.close()
+
+                b = cp.empty_like(a)
+                f = kvikio.CuFile(tmp_file_name, "r")
+                f.read(b)
+
     def verify_suffix(self, filename: Sequence[PathLike] | PathLike) -> bool:
         """
         Verify whether the specified file or files format is supported by Nibabel reader.
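A sketch of how the new option is meant to be used (assumes a CUDA-capable machine with cupy and kvikio installed; the file path is hypothetical):

    from monai.data.image_reader import NibabelReader

    # construction runs warmup_kvikio() once, so the first real read does not pay
    # the cuFile/GDS initialization cost
    reader = NibabelReader(to_gpu=True)

    img = reader.read("/data/ct_volume.nii")  # uncompressed .nii avoids the gzip round-trip
    data, meta = reader.get_data(img)         # data is a cupy.ndarray resident on the GPU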
@@ -916,6 +960,7 @@ def read(self, data: Sequence[PathLike] | PathLike, **kwargs):
         img_: list[Nifti1Image] = []
 
         filenames: Sequence[PathLike] = ensure_tuple(data)
+        self.filenames = filenames
         kwargs_ = self.kwargs.copy()
         kwargs_.update(kwargs)
         for name in filenames:
@@ -936,10 +981,13 @@ def get_data(self, img) -> tuple[np.ndarray, dict]:
             img: a Nibabel image object loaded from an image file or a list of Nibabel image objects.
 
         """
+        # TODO: the actual type is list[np.ndarray | cp.ndarray]
+        # should figure out how to define correct types without having cupy not found error
+        # https://github.com/Project-MONAI/MONAI/pull/8188#discussion_r1886645918
         img_array: list[np.ndarray] = []
         compatible_meta: dict = {}
 
-        for i in ensure_tuple(img):
+        for i, filename in zip(ensure_tuple(img), self.filenames):
             header = self._get_meta_dict(i)
             header[MetaKeys.AFFINE] = self._get_affine(i)
             header[MetaKeys.ORIGINAL_AFFINE] = self._get_affine(i)
@@ -949,7 +997,7 @@ def get_data(self, img) -> tuple[np.ndarray, dict]:
             header[MetaKeys.AFFINE] = self._get_affine(i)
             header[MetaKeys.SPATIAL_SHAPE] = self._get_spatial_shape(i)
             header[MetaKeys.SPACE] = SpaceKeys.RAS
-            data = self._get_array_data(i)
+            data = self._get_array_data(i, filename)
             if self.squeeze_non_spatial_dims:
                 for d in range(len(data.shape), len(header[MetaKeys.SPATIAL_SHAPE]), -1):
                     if data.shape[d - 1] == 1:
@@ -963,7 +1011,7 @@ def get_data(self, img) -> tuple[np.ndarray, dict]:
                 header[MetaKeys.ORIGINAL_CHANNEL_DIM] = self.channel_dim
             _copy_compatible_dict(header, compatible_meta)
 
-        return _stack_images(img_array, compatible_meta), compatible_meta
+        return _stack_images(img_array, compatible_meta, to_cupy=self.to_gpu), compatible_meta
 
     def _get_meta_dict(self, img) -> dict:
         """
@@ -1015,14 +1063,34 @@ def _get_spatial_shape(self, img):
         spatial_rank = max(min(ndim, 3), 1)
         return np.asarray(size[:spatial_rank])
 
-    def _get_array_data(self, img):
+    def _get_array_data(self, img, filename):
         """
         Get the raw array data of the image, converted to Numpy array.
 
         Args:
             img: a Nibabel image object loaded from an image file.
-
-        """
+            filename: file name of the image.
+
+        """
+        if self.to_gpu:
+            file_size = os.path.getsize(filename)
+            image = cp.empty(file_size, dtype=cp.uint8)
+            with kvikio.CuFile(filename, "r") as f:
+                f.read(image)
+            if filename.endswith(".nii.gz"):
+                # for compressed data, have to transfer to CPU to decompress
+                # and then transfer back to GPU. It is not efficient compared to .nii file
+                # and may be slower than CPU loading in some cases.
+                warnings.warn("Loading compressed NIfTI file into GPU may not be efficient.")
+                compressed_data = cp.asnumpy(image)
+                with gzip.GzipFile(fileobj=io.BytesIO(compressed_data)) as gz_file:
+                    decompressed_data = gz_file.read()
+
+                image = cp.frombuffer(decompressed_data, dtype=cp.uint8)
+            data_shape = img.shape
+            data_offset = img.dataobj.offset
+            data_dtype = img.dataobj.dtype
+            return image[data_offset:].view(data_dtype).reshape(data_shape, order="F")
         return np.asanyarray(img.dataobj, order="C")
 
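The GPU branch re-implements only the last step of NIfTI loading: skip the header, reinterpret the raw bytes, and reshape in Fortran order. A CPU-only NumPy analogue of that step, assuming an uncompressed, unscaled file with no trailing bytes (hypothetical path; real loading should still go through nibabel):

    import nibabel as nib
    import numpy as np

    path = "/data/ct_volume.nii"
    img = nib.load(path)

    raw = np.fromfile(path, dtype=np.uint8)
    offset = int(img.dataobj.offset)          # voxel data starts right after the header
    dtype = np.dtype(img.dataobj.dtype)
    n_bytes = int(np.prod(img.shape)) * dtype.itemsize

    arr = raw[offset : offset + n_bytes].view(dtype).reshape(img.shape, order="F")
    assert np.array_equal(arr, np.asanyarray(img.dataobj))  # holds when no scl_slope/scl_inter scaling applies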

monai/data/meta_tensor.py

Lines changed: 0 additions & 1 deletion

@@ -553,7 +553,6 @@ def ensure_torch_and_prune_meta(
         However, if `get_track_meta()` is `False` or meta=None, a `torch.Tensor` is returned.
         """
         img = convert_to_tensor(im, track_meta=get_track_meta() and meta is not None)  # potentially ascontiguousarray
-
         # if not tracking metadata, return `torch.Tensor`
         if not isinstance(img, MetaTensor):
             return img

monai/engines/workflow.py

Lines changed: 11 additions & 11 deletions

@@ -12,7 +12,7 @@
 from __future__ import annotations
 
 import warnings
-from collections.abc import Callable, Iterable, Sequence
+from collections.abc import Callable, Iterable, Sequence, Sized
 from typing import TYPE_CHECKING, Any
 
 import torch
@@ -121,24 +121,24 @@ def __init__(
         to_kwargs: dict | None = None,
         amp_kwargs: dict | None = None,
     ) -> None:
-        if iteration_update is not None:
-            super().__init__(iteration_update)
-        else:
-            super().__init__(self._iteration)
+        super().__init__(self._iteration if iteration_update is None else iteration_update)
 
         if isinstance(data_loader, DataLoader):
-            sampler = data_loader.__dict__["sampler"]
+            sampler = getattr(data_loader, "sampler", None)
+
+            # set the epoch value for DistributedSampler objects when an epoch starts
             if isinstance(sampler, DistributedSampler):
 
                 @self.on(Events.EPOCH_STARTED)
                 def set_sampler_epoch(engine: Engine) -> None:
                     sampler.set_epoch(engine.state.epoch)
 
-        if epoch_length is None:
+        # if the epoch_length isn't given, attempt to get it from the length of the data loader
+        if epoch_length is None and isinstance(data_loader, Sized):
+            try:
                 epoch_length = len(data_loader)
-        else:
-            if epoch_length is None:
-                raise ValueError("If data_loader is not PyTorch DataLoader, must specify the epoch_length.")
+            except TypeError:  # raised when data_loader has an iterable dataset with no length, or is some other type
+                pass  # deliberately leave epoch_length as None
 
         # set all sharable data for the workflow based on Ignite engine.state
         self.state: Any = State(
@@ -147,7 +147,7 @@ def set_sampler_epoch(engine: Engine) -> None:
             iteration=0,
             epoch=0,
             max_epochs=max_epochs,
-            epoch_length=epoch_length,
+            epoch_length=epoch_length,  # None when the dataset is iterable and so has no length
             output=None,
             batch=None,
             metrics={},
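The `Sized` check plus try/except exists because a `DataLoader` always defines `__len__`, yet calling it raises `TypeError` when the underlying dataset is an `IterableDataset` without a length. A small demonstration of the fallback path:

    from collections.abc import Sized

    from torch.utils.data import DataLoader, IterableDataset

    class Stream(IterableDataset):
        def __iter__(self):
            yield from range(4)

    loader = DataLoader(Stream())
    epoch_length = None
    if epoch_length is None and isinstance(loader, Sized):  # True: DataLoader defines __len__
        try:
            epoch_length = len(loader)  # TypeError: the IterableDataset has no __len__
        except TypeError:
            pass  # epoch_length deliberately stays None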

monai/networks/blocks/pos_embed_utils.py

Lines changed: 2 additions & 2 deletions

@@ -56,7 +56,7 @@ def build_sincos_position_embedding(
     grid_h = torch.arange(h, dtype=torch.float32)
     grid_w = torch.arange(w, dtype=torch.float32)
 
-    grid_h, grid_w = torch.meshgrid(grid_h, grid_w, indexing="ij")
+    grid_h, grid_w = torch.meshgrid(grid_h, grid_w)
 
     if embed_dim % 4 != 0:
         raise AssertionError("Embed dimension must be divisible by 4 for 2D sin-cos position embedding")
@@ -75,7 +75,7 @@ def build_sincos_position_embedding(
     grid_w = torch.arange(w, dtype=torch.float32)
     grid_d = torch.arange(d, dtype=torch.float32)
 
-    grid_h, grid_w, grid_d = torch.meshgrid(grid_h, grid_w, grid_d, indexing="ij")
+    grid_h, grid_w, grid_d = torch.meshgrid(grid_h, grid_w, grid_d)
 
     if embed_dim % 6 != 0:
         raise AssertionError("Embed dimension must be divisible by 6 for 3D sin-cos position embedding")
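Dropping the explicit `indexing="ij"` does not change the grids, because `torch.meshgrid` defaults to "ij" indexing (unlike NumPy's "xy" default); the likely motivation is compatibility with torch versions that predate the `indexing` argument. A quick equivalence check:

    import torch

    grid_h = torch.arange(3, dtype=torch.float32)
    grid_w = torch.arange(4, dtype=torch.float32)

    a = torch.meshgrid(grid_h, grid_w)                 # default; recent torch emits a UserWarning
    b = torch.meshgrid(grid_h, grid_w, indexing="ij")
    assert all(torch.equal(x, y) for x, y in zip(a, b))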

monai/transforms/io/array.py

Lines changed: 0 additions & 1 deletion

@@ -286,7 +286,6 @@ def __call__(self, filename: Sequence[PathLike] | PathLike, reader: ImageReader
                 " https://docs.monai.io/en/latest/installation.html#installing-the-recommended-dependencies.\n"
                 f" The current registered: {self.readers}.\n{msg}"
             )
-
         img_array: NdarrayOrTensor
         img_array, meta_data = reader.get_data(img)
         img_array = convert_to_dst_type(img_array, dst=img_array, dtype=self.dtype)[0]

setup.cfg

Lines changed: 8 additions & 8 deletions

@@ -61,10 +61,10 @@ all =
     tqdm>=4.47.0
     lmdb
     psutil
-    cucim-cu12; python_version >= '3.9' and python_version <= '3.10'
+    cucim-cu12; platform_system == "Linux" and python_version >= '3.9' and python_version <= '3.10'
     openslide-python
-    tifffile
-    imagecodecs
+    tifffile; platform_system == "Linux" or platform_system == "Darwin"
+    imagecodecs; platform_system == "Linux" or platform_system == "Darwin"
     pandas
     einops
     transformers>=4.36.0, <4.41.0; python_version <= '3.10'
@@ -78,7 +78,7 @@ all =
     pynrrd
     pydicom
     h5py
-    nni
+    nni; platform_system == "Linux" and "arm" not in platform_machine and "aarch" not in platform_machine
     optuna
     onnx>=1.13.0
     onnxruntime; python_version <= '3.10'
@@ -116,13 +116,13 @@ lmdb =
 psutil =
     psutil
 cucim =
-    cucim-cu12
+    cucim-cu12; platform_system == "Linux" and python_version >= '3.9' and python_version <= '3.10'
 openslide =
     openslide-python
 tifffile =
-    tifffile
+    tifffile; platform_system == "Linux" or platform_system == "Darwin"
 imagecodecs =
-    imagecodecs
+    imagecodecs; platform_system == "Linux" or platform_system == "Darwin"
 pandas =
     pandas
 einops =
@@ -152,7 +152,7 @@ pydicom =
 h5py =
     h5py
 nni =
-    nni
+    nni; platform_system == "Linux" and "arm" not in platform_machine and "aarch" not in platform_machine
 optuna =
     optuna
 onnx =
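These PEP 508 environment markers let one requirements list serve all platforms: pip evaluates each marker at install time and silently skips extras whose marker is false. A sketch of that evaluation using the `packaging` library (an assumption; pip uses the same machinery internally):

    from packaging.markers import Marker

    m = Marker('platform_system == "Linux" or platform_system == "Darwin"')
    print(m.evaluate())  # True on Linux/macOS, False on Windows, where tifffile is then skipped

    m = Marker('"arm" not in platform_machine and "aarch" not in platform_machine')
    print(m.evaluate())  # False on ARM machines (e.g. aarch64), so nni is not installed there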

tests/test_download_and_extract.py

Lines changed: 2 additions & 1 deletion

@@ -20,9 +20,10 @@
 from parameterized import parameterized
 
 from monai.apps import download_and_extract, download_url, extractall
-from tests.utils import skip_if_downloading_fails, skip_if_quick, testing_data_config
+from tests.utils import SkipIfNoModule, skip_if_downloading_fails, skip_if_quick, testing_data_config
 
 
+@SkipIfNoModule("requests")
 class TestDownloadAndExtract(unittest.TestCase):
 
     @skip_if_quick
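`SkipIfNoModule` keeps the test suite green when the new `requests` dependency is absent. A hedged sketch of what such a decorator looks like (MONAI's actual helper lives in `tests/utils.py` and may differ in detail):

    import unittest

    from monai.utils import optional_import

    class SkipIfNoModule:
        """Decorator: skip the test case or class when the named module cannot be imported."""

        def __init__(self, module_name: str):
            self.module_name = module_name
            self.module_missing = not optional_import(module_name)[1]

        def __call__(self, obj):
            return unittest.skipIf(self.module_missing, f"optional module not present: {self.module_name}")(obj)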
