Commit f453158

reformat to add gpu load support on nibabelreader
Signed-off-by: Yiheng Wang <[email protected]>
1 parent da41742 commit f453158

4 files changed: +59 -130 lines changed

monai/data/__init__.py (1 addition, 1 deletion)

@@ -50,7 +50,7 @@
 from .folder_layout import FolderLayout, FolderLayoutBase
 from .grid_dataset import GridPatchDataset, PatchDataset, PatchIter, PatchIterd
 from .image_dataset import ImageDataset
-from .image_reader import ImageReader, ITKReader, NibabelReader, NibabelGPUReader, NrrdReader, NumpyReader, PILReader, PydicomReader
+from .image_reader import ImageReader, ITKReader, NibabelReader, NrrdReader, NumpyReader, PILReader, PydicomReader
 from .image_writer import (
     SUPPORTED_WRITERS,
     ImageWriter,

monai/data/image_reader.py (44 additions, 99 deletions)

@@ -58,7 +58,7 @@
 cp, has_cp = optional_import("cupy")
 kvikio, has_kvikio = optional_import("kvikio")
 
-__all__ = ["ImageReader", "ITKReader", "NibabelReader", "NibabelGPUReader", "NumpyReader", "PILReader", "PydicomReader", "NrrdReader"]
+__all__ = ["ImageReader", "ITKReader", "NibabelReader", "NumpyReader", "PILReader", "PydicomReader", "NrrdReader"]
 
 
 class ImageReader(ABC):

@@ -155,6 +155,17 @@ def _stack_images(image_list: list, meta_dict: dict):
     return np.stack(image_list, axis=0)
 
 
+def _stack_gpu_images(image_list: list, meta_dict: dict):
+    if len(image_list) <= 1:
+        return image_list[0]
+    if not is_no_channel(meta_dict.get(MetaKeys.ORIGINAL_CHANNEL_DIM, None)):
+        channel_dim = int(meta_dict[MetaKeys.ORIGINAL_CHANNEL_DIM])
+        return cp.concatenate(image_list, axis=channel_dim)
+    # stack at a new first dim as the channel dim, if `'original_channel_dim'` is unspecified
+    meta_dict[MetaKeys.ORIGINAL_CHANNEL_DIM] = 0
+    return cp.stack(image_list, axis=0)
+
+
 @require_pkg(pkg_name="itk")
 class ITKReader(ImageReader):
     """

@@ -887,12 +898,15 @@ def __init__(
         channel_dim: str | int | None = None,
         as_closest_canonical: bool = False,
         squeeze_non_spatial_dims: bool = False,
+        gpu_load: bool = False,
         **kwargs,
     ):
         super().__init__()
         self.channel_dim = float("nan") if channel_dim == "no_channel" else channel_dim
         self.as_closest_canonical = as_closest_canonical
         self.squeeze_non_spatial_dims = squeeze_non_spatial_dims
+        # TODO: add a warning if the required libs are not available
+        self.gpu_load = gpu_load
         self.kwargs = kwargs
 
     def verify_suffix(self, filename: Sequence[PathLike] | PathLike) -> bool:

@@ -923,6 +937,7 @@ def read(self, data: Sequence[PathLike] | PathLike, **kwargs):
         img_: list[Nifti1Image] = []
 
         filenames: Sequence[PathLike] = ensure_tuple(data)
+        self.filenames = filenames
         kwargs_ = self.kwargs.copy()
         kwargs_.update(kwargs)
         for name in filenames:

@@ -946,7 +961,7 @@ def get_data(self, img) -> tuple[np.ndarray, dict]:
         img_array: list[np.ndarray] = []
         compatible_meta: dict = {}
 
-        for i in ensure_tuple(img):
+        for i, filename in zip(ensure_tuple(img), self.filenames):
             header = self._get_meta_dict(i)
             header[MetaKeys.AFFINE] = self._get_affine(i)
             header[MetaKeys.ORIGINAL_AFFINE] = self._get_affine(i)

@@ -956,7 +971,7 @@ def get_data(self, img) -> tuple[np.ndarray, dict]:
                 header[MetaKeys.AFFINE] = self._get_affine(i)
             header[MetaKeys.SPATIAL_SHAPE] = self._get_spatial_shape(i)
             header[MetaKeys.SPACE] = SpaceKeys.RAS
-            data = self._get_array_data(i)
+            data = self._get_array_data(i, filename)
             if self.squeeze_non_spatial_dims:
                 for d in range(len(data.shape), len(header[MetaKeys.SPATIAL_SHAPE]), -1):
                     if data.shape[d - 1] == 1:

@@ -969,7 +984,8 @@ def get_data(self, img) -> tuple[np.ndarray, dict]:
             else:
                 header[MetaKeys.ORIGINAL_CHANNEL_DIM] = self.channel_dim
             _copy_compatible_dict(header, compatible_meta)
-
+        if self.gpu_load:
+            return _stack_gpu_images(img_array, compatible_meta), compatible_meta
         return _stack_images(img_array, compatible_meta), compatible_meta
 
     def _get_meta_dict(self, img) -> dict:

@@ -1022,111 +1038,40 @@ def _get_spatial_shape(self, img):
         spatial_rank = max(min(ndim, 3), 1)
         return np.asarray(size[:spatial_rank])
 
-    def _get_array_data(self, img):
+    def _get_array_data(self, img, filename):
         """
         Get the raw array data of the image, converted to Numpy array.
 
         Args:
             img: a Nibabel image object loaded from an image file.
 
         """
+        if self.gpu_load:
+            file_size = os.path.getsize(filename)
+            image = cp.empty(file_size, dtype=cp.uint8)
+            # suggestion from Ming: more tests, diff size
+            # cucim + nifti
+            with kvikio.CuFile(filename, "r") as f:
+                f.read(image)
+            if filename.endswith(".gz"):
+                # for compressed data, have to transfer to CPU to decompress
+                # and then transfer back to GPU. It is not efficient compared to a .nii file,
+                # but it's still faster than Nibabel's default reader.
+                # TODO: benchmark more; this may be unnecessary since we don't have to use .gz,
+                # and it wastes time, especially in training
                compressed_data = cp.asnumpy(image)
+                with gzip.GzipFile(fileobj=io.BytesIO(compressed_data)) as gz_file:
+                    decompressed_data = gz_file.read()
+
+                file_size = len(decompressed_data)
+                image = cp.asarray(np.frombuffer(decompressed_data, dtype=np.uint8))
+            data_shape = img.shape
+            data_offset = img.dataobj.offset
+            data_dtype = img.dataobj.dtype
+            return image[data_offset:].view(data_dtype).reshape(data_shape, order="F")
         return np.asanyarray(img.dataobj, order="C")
 
 
-@require_pkg(pkg_name="nibabel")
-@require_pkg(pkg_name="cupy")
-@require_pkg(pkg_name="kvikio")
-class NibabelGPUReader(NibabelReader):
-
-    def read(self, filename: PathLike, **kwargs):
-        """
-        Read image data from specified file or files, it can read a list of images
-        and stack them together as multi-channel data in `get_data()`.
-        Note that the returned object is Nibabel image object or list of Nibabel image objects.
-
-        Args:
-            data: file name.
-
-        """
-        file_size = os.path.getsize(filename)
-        image = cp.empty(file_size, dtype=cp.uint8)
-        with kvikio.CuFile(filename, "r") as f:
-            f.read(image)
-
-        if filename.endswith(".gz"):
-            # for compressed data, have to transfer to CPU to decompress
-            # and then transfer back to GPU. It is not efficient compared to a .nii file,
-            # but it's still faster than Nibabel's default reader.
-            # TODO: benchmark more; this may be unnecessary since we don't have to use .gz,
-            # and it wastes time, especially in training
-            compressed_data = cp.asnumpy(image)
-            with gzip.GzipFile(fileobj=io.BytesIO(compressed_data)) as gz_file:
-                decompressed_data = gz_file.read()
-
-            file_size = len(decompressed_data)
-            image = cp.asarray(np.frombuffer(decompressed_data, dtype=np.uint8))
-        return image
-
-    def get_data(self, img):
-        """
-        Extract data array and metadata from loaded image and return them.
-        This function returns two objects, first is numpy array of image data, second is dict of metadata.
-        It constructs `affine`, `original_affine`, and `spatial_shape` and stores them in meta dict.
-        When loading a list of files, they are stacked together at a new dimension as the first dimension,
-        and the metadata of the first image is used to present the output metadata.
-
-        Args:
-            img: a Nibabel image object loaded from an image file.
-
-        """
-
-        # TODO: use a formal way for device
-        device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
-
-        header = self._get_header(img)
-        data_offset = header.get_data_offset()
-        data_shape = header.get_data_shape()
-        data_dtype = header.get_data_dtype()
-        affine = header.get_best_affine()
-        meta = {}
-        meta[MetaKeys.AFFINE] = affine
-        meta[MetaKeys.ORIGINAL_AFFINE] = affine
-        # TODO: as_closest_canonical
-        # TODO: correct_nifti_header_if_necessary
-        meta[MetaKeys.SPATIAL_SHAPE] = data_shape
-        # TODO: figure out why always RAS for NibabelReader ?
-        # meta[MetaKeys.SPACE] = SpaceKeys.RAS
-
-        data = img[data_offset:].view(data_dtype).reshape(data_shape, order="F")
-        # TODO: check channel
-        # if self.squeeze_non_spatial_dims:
-        if self.channel_dim is None:  # default to "no_channel" or -1
-            meta[MetaKeys.ORIGINAL_CHANNEL_DIM] = (
-                float("nan") if len(data.shape) == len(meta[MetaKeys.SPATIAL_SHAPE]) else -1
-            )
-        else:
-            meta[MetaKeys.ORIGINAL_CHANNEL_DIM] = self.channel_dim
-
-        return MetaTensor(data, affine=affine, meta=meta, device=device)
-
-    def _get_header(self, img):
-        """
-        Get all the metadata of the image and convert it to dict type.
-
-        Args:
-            img: a Nibabel image object loaded from an image file.
-
-        """
-        header_bytes = cp.asnumpy(img[:348])
-        header = nib.Nifti1Header.from_fileobj(io.BytesIO(header_bytes))
-        # swap to little endian as PyTorch doesn't support big endian
-        try:
-            header = header.as_byteswapped("<")
-        except ValueError:
-            pass
-        return header
-
-
 class NumpyReader(ImageReader):
     """
     Load NPY or NPZ format data based on Numpy library, they can be arrays or pickled objects.
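
For orientation, here is a minimal usage sketch of the gpu_load path added above; it is an illustration, not part of the commit, and assumes CuPy, kvikio and nibabel are installed, a CUDA device is available, and an uncompressed "image.nii" exists on disk (the fast path in this diff):

    # Hypothetical example exercising NibabelReader with the new gpu_load flag.
    from monai.data import NibabelReader

    reader = NibabelReader(gpu_load=True)    # flag introduced in this commit
    nifti_obj = reader.read("image.nii")     # header/metadata still parsed by nibabel
    data, meta = reader.get_data(nifti_obj)  # pixel data read via kvikio.CuFile into a CuPy array
    print(type(data), data.shape, meta["affine"])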

monai/data/meta_tensor.py (9 additions, 4 deletions)

@@ -532,7 +532,12 @@ def clone(self, **kwargs):
 
     @staticmethod
     def ensure_torch_and_prune_meta(
-        im: NdarrayTensor, meta: dict | None, simple_keys: bool = False, pattern: str | None = None, sep: str = "."
+        im: NdarrayTensor,
+        meta: dict | None,
+        simple_keys: bool = False,
+        pattern: str | None = None,
+        sep: str = ".",
+        device: None | str | torch.device = None,
     ):
         """
         Convert the image to MetaTensor (when meta is not None). If `affine` is in the `meta` dictionary,

@@ -547,13 +552,13 @@ def ensure_torch_and_prune_meta(
             sep: combined with `pattern`, used to match and delete keys in the metadata (nested dictionary).
                 default is ".", see also :py:class:`monai.transforms.DeleteItemsd`.
                 e.g. ``pattern=".*_code$", sep=" "`` removes any meta keys that ends with ``"_code"``.
+            device: target device to put the Tensor data.
 
         Returns:
             By default, a `MetaTensor` is returned.
             However, if `get_track_meta()` is `False` or meta=None, a `torch.Tensor` is returned.
         """
-        img = convert_to_tensor(im, track_meta=get_track_meta() and meta is not None)  # potentially ascontiguousarray
-
+        img = convert_to_tensor(im, track_meta=get_track_meta() and meta is not None, device=device)  # potentially ascontiguousarray
         # if not tracking metadata, return `torch.Tensor`
         if not isinstance(img, MetaTensor):
             return img

@@ -565,7 +570,7 @@ def ensure_torch_and_prune_meta(
         if simple_keys:
             # ensure affine is of type `torch.Tensor`
             if MetaKeys.AFFINE in meta:
-                meta[MetaKeys.AFFINE] = convert_to_tensor(meta[MetaKeys.AFFINE])  # bc-breaking
+                meta[MetaKeys.AFFINE] = convert_to_tensor(meta[MetaKeys.AFFINE], device=device)  # bc-breaking
             remove_extra_metadata(meta)  # bc-breaking
 
         if pattern is not None:
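
A rough sketch of how the extended signature might be called directly (an illustration only, assuming a CUDA device is available); with the new device argument the converted tensor is placed on the target device during loading rather than moved afterwards:

    # Hypothetical example of ensure_torch_and_prune_meta with the new device argument.
    import numpy as np
    from monai.data import MetaTensor

    arr = np.zeros((1, 4, 4, 4), dtype=np.float32)
    meta = {"affine": np.eye(4), "original_channel_dim": 0}
    img = MetaTensor.ensure_torch_and_prune_meta(arr, meta, simple_keys=True, device="cuda:0")
    print(img.device)  # expected to report cuda:0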

monai/transforms/io/array.py (5 additions, 26 deletions)

@@ -35,7 +35,6 @@
     ImageReader,
     ITKReader,
     NibabelReader,
-    NibabelGPUReader,
     NrrdReader,
     NumpyReader,
     PILReader,

@@ -140,6 +139,7 @@ def __init__(
         prune_meta_pattern: str | None = None,
         prune_meta_sep: str = ".",
         expanduser: bool = True,
+        device: None | str | torch.device = None,
         *args,
         **kwargs,
     ) -> None:

@@ -164,6 +164,7 @@ def __init__(
                 e.g. ``prune_meta_pattern=".*_code$", prune_meta_sep=" "`` removes meta keys that ends with ``"_code"``.
             expanduser: if True cast filename to Path and call .expanduser on it, otherwise keep filename as is.
             args: additional parameters for reader if providing a reader name.
+            device: target device to put the loaded image.
             kwargs: additional parameters for reader if providing a reader name.
 
         Note:

@@ -185,6 +186,7 @@ def __init__(
         self.pattern = prune_meta_pattern
         self.sep = prune_meta_sep
         self.expanduser = expanduser
+        self.device = device
 
         self.readers: list[ImageReader] = []
         for r in SUPPORTED_READERS:  # set predefined readers as default

@@ -257,18 +259,6 @@ def __call__(self, filename: Sequence[PathLike] | PathLike, reader: ImageReader
         )
         img, err = None, []
         if reader is not None:
-            if isinstance(reader, NibabelGPUReader):
-                # TODO: handle multiple filenames later
-                buffer = reader.read(filename[0])
-                img = reader.get_data(buffer)
-                img.meta[Key.FILENAME_OR_OBJ] = filename[0]
-                # TODO: check ensure channel first
-                if self.ensure_channel_first:
-                    img = EnsureChannelFirst()(img)
-                if self.image_only:
-                    return img
-                return img, img.meta
-
             img = reader.read(filename)  # runtime specified reader
         else:
             for reader in self.readers[::-1]:

@@ -278,17 +268,6 @@ def __call__(self, filename: Sequence[PathLike] | PathLike, reader: ImageReader
                         break
                 else:  # try the user designated readers
                     try:
-                        if isinstance(reader, NibabelGPUReader):
-                            # TODO: handle multiple filenames later
-                            buffer = reader.read(filename[0])
-                            img = reader.get_data(buffer)
-                            img.meta[Key.FILENAME_OR_OBJ] = filename[0]
-                            # TODO: check ensure channel first
-                            if self.ensure_channel_first:
-                                img = EnsureChannelFirst()(img)
-                            if self.image_only:
-                                return img
-                            return img, img.meta
                         img = reader.read(filename)
                     except Exception as e:
                         err.append(traceback.format_exc())

@@ -312,15 +291,15 @@ def __call__(self, filename: Sequence[PathLike] | PathLike, reader: ImageReader
         )
         img_array: NdarrayOrTensor
         img_array, meta_data = reader.get_data(img)
-        img_array = convert_to_dst_type(img_array, dst=img_array, dtype=self.dtype)[0]
+        img_array = convert_to_dst_type(img_array, dst=img_array, dtype=self.dtype, device=self.device)[0]
         if not isinstance(meta_data, dict):
             raise ValueError(f"`meta_data` must be a dict, got type {type(meta_data)}.")
         # make sure all elements in metadata are little endian
         meta_data = switch_endianness(meta_data, "<")
 
         meta_data[Key.FILENAME_OR_OBJ] = f"{ensure_tuple(filename)[0]}"  # Path obj should be strings for data loader
         img = MetaTensor.ensure_torch_and_prune_meta(
-            img_array, meta_data, self.simple_keys, pattern=self.pattern, sep=self.sep
+            img_array, meta_data, self.simple_keys, pattern=self.pattern, sep=self.sep, device=self.device
         )
         if self.ensure_channel_first:
             img = EnsureChannelFirst()(img)
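
Downstream, the same device option surfaces on the LoadImage transform; a hedged end-to-end sketch (not part of the commit), assuming a CUDA device and an existing "image.nii.gz":

    # Hypothetical example of LoadImage with the new device argument.
    from monai.transforms import LoadImage

    loader = LoadImage(image_only=True, ensure_channel_first=True, device="cuda:0")
    img = loader("image.nii.gz")  # MetaTensor placed on cuda:0 by ensure_torch_and_prune_meta
    print(img.device, img.shape)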
