Skip to content

Commit 84d8cf3

Browse files
enable gpu load nifti
Signed-off-by: Yiheng Wang <[email protected]>
1 parent c1ceea3 commit 84d8cf3

File tree

2 files changed

+127
-2
lines changed

2 files changed

+127
-2
lines changed

monai/data/image_reader.py

Lines changed: 125 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,15 @@
1414
import glob
1515
import os
1616
import re
17+
import gzip
18+
import io
1719
import warnings
1820
from abc import ABC, abstractmethod
1921
from collections.abc import Callable, Iterable, Iterator, Sequence
2022
from dataclasses import dataclass
2123
from pathlib import Path
2224
from typing import TYPE_CHECKING, Any
25+
import torch
2326

2427
import numpy as np
2528
from torch.utils.data._utils.collate import np_str_obj_array_pattern
@@ -41,17 +44,21 @@
4144
import pydicom
4245
from nibabel.nifti1 import Nifti1Image
4346
from PIL import Image as PILImage
47+
import cupy as cp
48+
import kvikio
4449

45-
has_nrrd = has_itk = has_nib = has_pil = has_pydicom = True
50+
has_nrrd = has_itk = has_nib = has_pil = has_pydicom = has_cp = has_kvikio = True
4651
else:
4752
itk, has_itk = optional_import("itk", allow_namespace_pkg=True)
4853
nib, has_nib = optional_import("nibabel")
4954
Nifti1Image, _ = optional_import("nibabel.nifti1", name="Nifti1Image")
5055
PILImage, has_pil = optional_import("PIL.Image")
5156
pydicom, has_pydicom = optional_import("pydicom")
5257
nrrd, has_nrrd = optional_import("nrrd", allow_namespace_pkg=True)
58+
cp, has_cp = optional_import("cupy")
59+
kvikio, has_kvikio = optional_import("kvikio")
5360

54-
__all__ = ["ImageReader", "ITKReader", "NibabelReader", "NumpyReader", "PILReader", "PydicomReader", "NrrdReader"]
61+
__all__ = ["ImageReader", "ITKReader", "NibabelReader", "NibabelGPUReader", "NumpyReader", "PILReader", "PydicomReader", "NrrdReader"]
5562

5663

5764
class ImageReader(ABC):
@@ -1024,6 +1031,122 @@ def _get_array_data(self, img):
10241031
10251032
"""
10261033
return np.asanyarray(img.dataobj, order="C")
1034+
1035+
1036+
@require_pkg(pkg_name="nibabel")
@require_pkg(pkg_name="cupy")
@require_pkg(pkg_name="kvikio")
class NibabelGPUReader(NibabelReader):
    """
    Load NIfTI files as CuPy arrays by reading the raw file bytes with ``kvikio.CuFile``.

    ``read()`` returns the raw file content as a 1-D ``uint8`` CuPy array (or a list of them);
    ``get_data()`` parses the NIfTI-1 header from those bytes and reinterprets the remaining
    bytes as the image array, without going through ``nibabel``'s image loading path.

    ``.gz`` files are decompressed on the CPU and copied back to the GPU.
    """

    def _gds_load(self, file_path):
        """
        Read the whole file into a 1-D ``uint8`` CuPy array.

        Args:
            file_path: path of the file to read, ``str`` or ``os.PathLike``.

        Returns:
            the (decompressed, if the name ends with ``.gz``) file content as a CuPy ``uint8`` array.

        """
        # normalise to ``str`` so that ``pathlib.Path`` inputs support ``endswith`` below
        path = os.fsdecode(file_path)
        image = cp.empty(os.path.getsize(path), dtype=cp.uint8)
        with kvikio.CuFile(path, "r") as f:
            f.read(image)

        if path.endswith(".gz"):
            # For compressed data, we have to transfer to CPU to decompress
            # and then transfer back to GPU. This is less efficient than reading a plain .nii file.
            # TODO: benchmark further; it may be better not to support .gz here at all,
            # since the extra round trip wastes time, especially in training.
            compressed_data = cp.asnumpy(image)
            with gzip.GzipFile(fileobj=io.BytesIO(compressed_data)) as gz_file:
                decompressed_data = gz_file.read()
            image = cp.asarray(np.frombuffer(decompressed_data, dtype=np.uint8))

        return image

    def read(self, data: Sequence[PathLike] | PathLike, **kwargs):
        """
        Read the raw bytes of the specified file or files into GPU memory.
        A list of files can be given; they are stacked together as multi-channel data in `get_data()`.
        Note that the returned object is a 1-D ``uint8`` CuPy array holding the file content,
        or a list of such arrays when more than one file name is given.

        Args:
            data: file name or a list of file names to read.
            kwargs: accepted for interface compatibility; currently not used by this reader.

        """
        filenames: Sequence[PathLike] = ensure_tuple(data)
        buffers = [self._gds_load(name) for name in filenames]
        return buffers if len(filenames) > 1 else buffers[0]

    def get_data(self, img):
        """
        Extract data array and metadata from loaded file buffers and return them.
        This function returns two objects, first is a CuPy array of image data, second is dict of metadata.
        It constructs `affine`, `original_affine`, and `spatial_shape` and stores them in meta dict.
        When loading a list of files, they are stacked together at a new dimension as the first dimension,
        and the metadata of the first image is used to present the output metadata.

        Args:
            img: a raw file buffer returned by `read()`, or a list of such buffers.

        """
        compatible_meta: dict = {}
        img_array = []
        for i in ensure_tuple(img):
            header = self._get_header(i)
            data_offset = header.get_data_offset()
            data_shape = header.get_data_shape()
            data_dtype = header.get_data_dtype()
            affine = header.get_best_affine()
            meta = dict(header)
            meta[MetaKeys.AFFINE] = affine
            meta[MetaKeys.ORIGINAL_AFFINE] = affine
            # TODO: as_closest_canonical
            # TODO: correct_nifti_header_if_necessary
            meta[MetaKeys.SPATIAL_SHAPE] = data_shape
            # TODO: figure out why always RAS for NibabelReader ?
            # meta[MetaKeys.SPACE] = SpaceKeys.RAS

            # everything after the header offset is voxel data, stored in Fortran order
            data = i[data_offset:].view(data_dtype).reshape(data_shape, order="F")
            # TODO: check channel
            # if self.squeeze_non_spatial_dims:
            img_array.append(data)
            # the channel info must go into ``meta`` (the plain dict), not into the Nifti1Header,
            # which has no such fields; ``meta`` is also what carries affine/spatial_shape forward
            if self.channel_dim is None:  # default to "no_channel" or -1
                meta[MetaKeys.ORIGINAL_CHANNEL_DIM] = (
                    float("nan") if len(data.shape) == len(meta[MetaKeys.SPATIAL_SHAPE]) else -1
                )
            else:
                meta[MetaKeys.ORIGINAL_CHANNEL_DIM] = self.channel_dim
            _copy_compatible_dict(meta, compatible_meta)

        return self._stack_images(img_array, compatible_meta), compatible_meta

    def _get_header(self, img):
        """
        Parse the NIfTI-1 header from the first 348 bytes of a raw file buffer.

        Args:
            img: a 1-D ``uint8`` CuPy array holding the content of a NIfTI file.

        Returns:
            a ``nibabel.Nifti1Header``, converted to little endian when possible.

        """
        header_bytes = cp.asnumpy(img[:348]).tobytes()
        header = nib.Nifti1Header.from_fileobj(io.BytesIO(header_bytes))
        # swap to little endian as PyTorch doesn't support big endian
        # NOTE(review): only the header is swapped here; the voxel bytes in `img` are left
        # untouched, so big-endian files presumably need extra handling in `get_data` - confirm.
        try:
            header = header.as_byteswapped("<")
        except ValueError:
            # best effort: keep the header in its original byte order
            pass
        return header

    def _stack_images(self, image_list: list, meta_dict: dict):
        """
        Combine the per-file arrays into a single CuPy array.

        A single array is returned as is. If the metadata declares a channel dimension, the arrays
        are concatenated along it; otherwise they are stacked along a new first dimension and
        ``meta_dict`` is updated to record channel dimension ``0``.

        Args:
            image_list: CuPy arrays extracted from each file.
            meta_dict: compatible metadata of the arrays; may be updated in place.

        """
        if len(image_list) <= 1:
            return image_list[0]
        if not is_no_channel(meta_dict.get(MetaKeys.ORIGINAL_CHANNEL_DIM, None)):
            channel_dim = int(meta_dict[MetaKeys.ORIGINAL_CHANNEL_DIM])
            # the arrays are CuPy arrays, so combine them with CuPy rather than torch
            return cp.concatenate(image_list, axis=channel_dim)
        meta_dict[MetaKeys.ORIGINAL_CHANNEL_DIM] = 0
        return cp.stack(image_list, axis=0)
10271150

10281151

10291152
class NumpyReader(ImageReader):

monai/transforms/io/array.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
ImageReader,
3636
ITKReader,
3737
NibabelReader,
38+
NibabelGPUReader,
3839
NrrdReader,
3940
NumpyReader,
4041
PILReader,
@@ -69,6 +70,7 @@
6970
"numpyreader": NumpyReader,
7071
"pilreader": PILReader,
7172
"nibabelreader": NibabelReader,
73+
"nibabelgpureader": NibabelGPUReader,
7274
}
7375

7476

0 commit comments

Comments
 (0)