Skip to content

Commit 7eb890f

Browse files
update
Signed-off-by: Yiheng Wang <[email protected]>
1 parent 59eccd4 commit 7eb890f

File tree

1 file changed

+24
-24
lines changed

1 file changed

+24
-24
lines changed

monai/data/image_reader.py

Lines changed: 24 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -142,28 +142,21 @@ def _copy_compatible_dict(from_dict: dict, to_dict: dict):
142142
)
143143

144144

145-
def _stack_images(image_list: list, meta_dict: dict):
145+
def _stack_images(image_list: list, meta_dict: dict, to_cupy: bool = False):
146146
if len(image_list) <= 1:
147147
return image_list[0]
148148
if not is_no_channel(meta_dict.get(MetaKeys.ORIGINAL_CHANNEL_DIM, None)):
149149
channel_dim = int(meta_dict[MetaKeys.ORIGINAL_CHANNEL_DIM])
150+
if to_cupy and has_cp:
151+
return cp.concatenate(image_list, axis=channel_dim)
150152
return np.concatenate(image_list, axis=channel_dim)
151153
# stack at a new first dim as the channel dim, if `'original_channel_dim'` is unspecified
152154
meta_dict[MetaKeys.ORIGINAL_CHANNEL_DIM] = 0
155+
if to_cupy and has_cp:
156+
return cp.stack(image_list, axis=0)
153157
return np.stack(image_list, axis=0)
154158

155159

156-
def _stack_gpu_images(image_list: list, meta_dict: dict):
157-
if len(image_list) <= 1:
158-
return image_list[0]
159-
if not is_no_channel(meta_dict.get(MetaKeys.ORIGINAL_CHANNEL_DIM, None)):
160-
channel_dim = int(meta_dict[MetaKeys.ORIGINAL_CHANNEL_DIM])
161-
return cp.concatenate(image_list, axis=channel_dim)
162-
# stack at a new first dim as the channel dim, if `'original_channel_dim'` is unspecified
163-
meta_dict[MetaKeys.ORIGINAL_CHANNEL_DIM] = 0
164-
return cp.stack(image_list, axis=0)
165-
166-
167160
@require_pkg(pkg_name="itk")
168161
class ITKReader(ImageReader):
169162
"""
@@ -880,12 +873,16 @@ class NibabelReader(ImageReader):
880873
Load NIfTI format images based on Nibabel library.
881874
882875
Args:
883-
as_closest_canonical: if True, load the image as closest to canonical axis format.
884-
squeeze_non_spatial_dims: if True, non-spatial singletons will be squeezed, e.g. (256,256,1,3) -> (256,256,3)
885876
channel_dim: the channel dimension of the input image, default is None.
886877
this is used to set original_channel_dim in the metadata, EnsureChannelFirstD reads this field.
887878
if None, `original_channel_dim` will be either `no_channel` or `-1`.
888879
most Nifti files are usually "channel last", no need to specify this argument for them.
880+
as_closest_canonical: if True, load the image as closest to canonical axis format.
881+
squeeze_non_spatial_dims: if True, non-spatial singletons will be squeezed, e.g. (256,256,1,3) -> (256,256,3)
882+
to_gpu: If True, load the image into GPU memory using CuPy and Kvikio. This can accelerate data loading.
883+
Default is False. CuPy and Kvikio are required for this option.
884+
Note: For compressed NIfTI files, some operations may still be performed on CPU memory,
885+
and the acceleration may not be significant.
889886
kwargs: additional args for `nibabel.load` API. more details about available args:
890887
https://github.com/nipy/nibabel/blob/master/nibabel/loadsave.py
891888
@@ -896,15 +893,22 @@ def __init__(
896893
channel_dim: str | int | None = None,
897894
as_closest_canonical: bool = False,
898895
squeeze_non_spatial_dims: bool = False,
899-
gpu_load: bool = False,
896+
to_gpu: bool = False,
900897
**kwargs,
901898
):
902899
super().__init__()
903900
self.channel_dim = float("nan") if channel_dim == "no_channel" else channel_dim
904901
self.as_closest_canonical = as_closest_canonical
905902
self.squeeze_non_spatial_dims = squeeze_non_spatial_dims
906-
# TODO: add warning if not have required libs
907-
self.gpu_load = gpu_load
903+
if to_gpu is True:
904+
if not has_cp:
905+
warnings.warn("CuPy is not installed, fall back to use cpu load.")
906+
to_gpu = False
907+
if not has_kvikio:
908+
warnings.warn("Kvikio is not installed, fall back to use cpu load.")
909+
to_gpu = False
910+
911+
self.to_gpu = to_gpu
908912
self.kwargs = kwargs
909913

910914
def verify_suffix(self, filename: Sequence[PathLike] | PathLike) -> bool:
@@ -982,8 +986,8 @@ def get_data(self, img) -> tuple[np.ndarray, dict]:
982986
else:
983987
header[MetaKeys.ORIGINAL_CHANNEL_DIM] = self.channel_dim
984988
_copy_compatible_dict(header, compatible_meta)
985-
if self.gpu_load:
986-
return _stack_gpu_images(img_array, compatible_meta), compatible_meta
989+
if self.to_gpu:
990+
return _stack_images(img_array, compatible_meta, to_cupy=True), compatible_meta
987991
return _stack_images(img_array, compatible_meta), compatible_meta
988992

989993
def _get_meta_dict(self, img) -> dict:
@@ -1047,22 +1051,18 @@ def _get_array_data(self, img, filename):
10471051
if self.gpu_load:
10481052
file_size = os.path.getsize(filename)
10491053
image = cp.empty(file_size, dtype=cp.uint8)
1050-
# suggestion from Ming: more tests, diff size
1051-
# cucim + nifti
10521054
with kvikio.CuFile(filename, "r") as f:
10531055
f.read(image)
10541056
if filename.endswith(".gz"):
10551057
# for compressed data, have to transfer to CPU to decompress
10561058
# and then transfer back to GPU. This is less efficient than reading a .nii file,
10571059
# but it's still faster than Nibabel's default reader.
1058-
# TODO: can benchmark more, it may no need to do this since we don't have to use .gz
1059-
# since it's waste times especially in training
10601060
compressed_data = cp.asnumpy(image)
10611061
with gzip.GzipFile(fileobj=io.BytesIO(compressed_data)) as gz_file:
10621062
decompressed_data = gz_file.read()
10631063

10641064
file_size = len(decompressed_data)
1065-
image = cp.asarray(np.frombuffer(decompressed_data, dtype=np.uint8))
1065+
image = cp.frombuffer(decompressed_data, dtype=cp.uint8)
10661066
data_shape = img.shape
10671067
data_offset = img.dataobj.offset
10681068
data_dtype = img.dataobj.dtype

0 commit comments

Comments
 (0)