Skip to content

Commit 01a21e0

Browse files
update loadimage
Signed-off-by: Yiheng Wang <[email protected]>
1 parent d3551cc commit 01a21e0

File tree

2 files changed

+53
-66
lines changed

2 files changed

+53
-66
lines changed

monai/data/image_reader.py

Lines changed: 43 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
from pathlib import Path
2424
from typing import TYPE_CHECKING, Any
2525
import torch
26-
26+
from monai.data.meta_tensor import MetaTensor
2727
import numpy as np
2828
from torch.utils.data._utils.collate import np_str_obj_array_pattern
2929

@@ -1038,13 +1038,22 @@ def _get_array_data(self, img):
10381038
@require_pkg(pkg_name="kvikio")
10391039
class NibabelGPUReader(NibabelReader):
10401040

1041-
def _gds_load(self, file_path):
1042-
file_size = os.path.getsize(file_path)
1041+
def read(self, filename: PathLike, **kwargs):
1042+
"""
1043+
Read image data from specified file or files, it can read a list of images
1044+
and stack them together as multi-channel data in `get_data()`.
1045+
Note that the returned object is Nibabel image object or list of Nibabel image objects.
1046+
1047+
Args:
1048+
filename: file name.
1049+
1050+
"""
1051+
file_size = os.path.getsize(filename)
10431052
image = cp.empty(file_size, dtype=cp.uint8)
1044-
with kvikio.CuFile(file_path, "r") as f:
1053+
with kvikio.CuFile(filename, "r") as f:
10451054
f.read(image)
10461055

1047-
if file_path.endswith(".gz"):
1056+
if filename.endswith(".gz"):
10481057
# for compressed data, have to transfer to CPU to decompress
10491058
# and then transfer back to GPU. It is not efficient compared to .nii file
10501059
# but it's still faster than Nibabel's default reader.
@@ -1056,29 +1065,8 @@ def _gds_load(self, file_path):
10561065

10571066
file_size = len(decompressed_data)
10581067
image = cp.asarray(np.frombuffer(decompressed_data, dtype=np.uint8))
1059-
10601068
return image
10611069

1062-
def read(self, data: Sequence[PathLike] | PathLike, **kwargs):
1063-
"""
1064-
Read image data from specified file or files, it can read a list of images
1065-
and stack them together as multi-channel data in `get_data()`.
1066-
Note that the returned object is Nibabel image object or list of Nibabel image objects.
1067-
1068-
Args:
1069-
data: file name or a list of file names to read.
1070-
1071-
"""
1072-
img_ = []
1073-
1074-
filenames: Sequence[PathLike] = ensure_tuple(data)
1075-
kwargs_ = self.kwargs.copy()
1076-
kwargs_.update(kwargs)
1077-
for name in filenames:
1078-
img = self._gds_load(name)
1079-
img_.append(img) # type: ignore
1080-
return img_ if len(filenames) > 1 else img_[0]
1081-
10821070
def get_data(self, img):
10831071
"""
10841072
Extract data array and metadata from loaded image and return them.
@@ -1088,39 +1076,38 @@ def get_data(self, img):
10881076
and the metadata of the first image is used to present the output metadata.
10891077
10901078
Args:
1091-
img: a Nibabel image object loaded from an image file or a list of Nibabel image objects.
1079+
img: a Nibabel image object loaded from an image file.
10921080
10931081
"""
1094-
compatible_meta: dict = {}
1095-
img_array = []
1096-
for i in ensure_tuple(img):
1097-
header = self._get_header(i)
1098-
data_offset = header.get_data_offset()
1099-
data_shape = header.get_data_shape()
1100-
data_dtype = header.get_data_dtype()
1101-
affine = header.get_best_affine()
1102-
meta = dict(header)
1103-
meta[MetaKeys.AFFINE] = affine
1104-
meta[MetaKeys.ORIGINAL_AFFINE] = affine
1105-
# TODO: as_closest_canonical
1106-
# TODO: correct_nifti_header_if_necessary
1107-
meta[MetaKeys.SPATIAL_SHAPE] = data_shape
1108-
# TODO: figure out why always RAS for NibabelReader ?
1109-
# meta[MetaKeys.SPACE] = SpaceKeys.RAS
1110-
1111-
data = i[data_offset:].view(data_dtype).reshape(data_shape, order="F")
1112-
# TODO: check channel
1113-
# if self.squeeze_non_spatial_dims:
1114-
img_array.append(data)
1115-
if self.channel_dim is None: # default to "no_channel" or -1
1116-
meta[MetaKeys.ORIGINAL_CHANNEL_DIM] = (
1117-
float("nan") if len(data.shape) == len(meta[MetaKeys.SPATIAL_SHAPE]) else -1
1118-
)
1119-
else:
1120-
meta[MetaKeys.ORIGINAL_CHANNEL_DIM] = self.channel_dim
1121-
_copy_compatible_dict(meta, compatible_meta)
11221082

1123-
return self._stack_images(img_array, compatible_meta), compatible_meta
1083+
# TODO: use a formal way for device
1084+
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
1085+
1086+
header = self._get_header(img)
1087+
data_offset = header.get_data_offset()
1088+
data_shape = header.get_data_shape()
1089+
data_dtype = header.get_data_dtype()
1090+
affine = header.get_best_affine()
1091+
meta = dict(header)
1092+
meta[MetaKeys.AFFINE] = affine
1093+
meta[MetaKeys.ORIGINAL_AFFINE] = affine
1094+
# TODO: as_closest_canonical
1095+
# TODO: correct_nifti_header_if_necessary
1096+
meta[MetaKeys.SPATIAL_SHAPE] = data_shape
1097+
# TODO: figure out why always RAS for NibabelReader ?
1098+
# meta[MetaKeys.SPACE] = SpaceKeys.RAS
1099+
1100+
data = img[data_offset:].view(data_dtype).reshape(data_shape, order="F")
1101+
# TODO: check channel
1102+
# if self.squeeze_non_spatial_dims:
1103+
if self.channel_dim is None: # default to "no_channel" or -1
1104+
meta[MetaKeys.ORIGINAL_CHANNEL_DIM] = (
1105+
float("nan") if len(data.shape) == len(meta[MetaKeys.SPATIAL_SHAPE]) else -1
1106+
)
1107+
else:
1108+
meta[MetaKeys.ORIGINAL_CHANNEL_DIM] = self.channel_dim
1109+
1110+
return MetaTensor(data, affine=affine, meta=meta, device=device)
11241111

11251112
def _get_header(self, img):
11261113
"""
@@ -1139,15 +1126,6 @@ def _get_header(self, img):
11391126
pass
11401127
return header
11411128

1142-
def _stack_images(self, image_list: list, meta_dict: dict):
1143-
if len(image_list) <= 1:
1144-
return image_list[0]
1145-
if not is_no_channel(meta_dict.get(MetaKeys.ORIGINAL_CHANNEL_DIM, None)):
1146-
channel_dim = int(meta_dict[MetaKeys.ORIGINAL_CHANNEL_DIM])
1147-
return torch.cat(image_list, axis=channel_dim)
1148-
meta_dict[MetaKeys.ORIGINAL_CHANNEL_DIM] = 0
1149-
return torch.stack(image_list, dim=0)
1150-
11511129

11521130
class NumpyReader(ImageReader):
11531131
"""

monai/transforms/io/array.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -258,6 +258,16 @@ def __call__(self, filename: Sequence[PathLike] | PathLike, reader: ImageReader
258258
)
259259
img, err = None, []
260260
if reader is not None:
261+
if isinstance(reader, NibabelGPUReader):
262+
buffer = reader.read(filename)
263+
img = reader.get_data(buffer)
264+
# TODO: check ensure channel first
265+
if self.ensure_channel_first:
266+
img = EnsureChannelFirst()(img)
267+
if self.image_only:
268+
return img
269+
return img, img.meta
270+
261271
img = reader.read(filename) # runtime specified reader
262272
else:
263273
for reader in self.readers[::-1]:
@@ -288,7 +298,6 @@ def __call__(self, filename: Sequence[PathLike] | PathLike, reader: ImageReader
288298
" https://docs.monai.io/en/latest/installation.html#installing-the-recommended-dependencies.\n"
289299
f" The current registered: {self.readers}.\n{msg}"
290300
)
291-
292301
img_array: NdarrayOrTensor
293302
img_array, meta_data = reader.get_data(img)
294303
img_array = convert_to_dst_type(img_array, dst=img_array, dtype=self.dtype)[0]

0 commit comments

Comments
 (0)