5858 cp , has_cp = optional_import ("cupy" )
5959 kvikio , has_kvikio = optional_import ("kvikio" )
6060
61- __all__ = ["ImageReader" , "ITKReader" , "NibabelReader" , "NibabelGPUReader" , " NumpyReader" , "PILReader" , "PydicomReader" , "NrrdReader" ]
61+ __all__ = ["ImageReader" , "ITKReader" , "NibabelReader" , "NumpyReader" , "PILReader" , "PydicomReader" , "NrrdReader" ]
6262
6363
6464class ImageReader (ABC ):
@@ -155,6 +155,17 @@ def _stack_images(image_list: list, meta_dict: dict):
155155 return np .stack (image_list , axis = 0 )
156156
157157
158+ def _stack_gpu_images (image_list : list , meta_dict : dict ):
159+ if len (image_list ) <= 1 :
160+ return image_list [0 ]
161+ if not is_no_channel (meta_dict .get (MetaKeys .ORIGINAL_CHANNEL_DIM , None )):
162+ channel_dim = int (meta_dict [MetaKeys .ORIGINAL_CHANNEL_DIM ])
163+ return cp .concatenate (image_list , axis = channel_dim )
164+ # stack at a new first dim as the channel dim, if `'original_channel_dim'` is unspecified
165+ meta_dict [MetaKeys .ORIGINAL_CHANNEL_DIM ] = 0
166+ return cp .stack (image_list , axis = 0 )
167+
168+
158169@require_pkg (pkg_name = "itk" )
159170class ITKReader (ImageReader ):
160171 """
@@ -887,12 +898,15 @@ def __init__(
887898 channel_dim : str | int | None = None ,
888899 as_closest_canonical : bool = False ,
889900 squeeze_non_spatial_dims : bool = False ,
901+ gpu_load : bool = False ,
890902 ** kwargs ,
891903 ):
892904 super ().__init__ ()
893905 self .channel_dim = float ("nan" ) if channel_dim == "no_channel" else channel_dim
894906 self .as_closest_canonical = as_closest_canonical
895907 self .squeeze_non_spatial_dims = squeeze_non_spatial_dims
908+ # TODO: add warning if not have required libs
909+ self .gpu_load = gpu_load
896910 self .kwargs = kwargs
897911
898912 def verify_suffix (self , filename : Sequence [PathLike ] | PathLike ) -> bool :
@@ -923,6 +937,7 @@ def read(self, data: Sequence[PathLike] | PathLike, **kwargs):
923937 img_ : list [Nifti1Image ] = []
924938
925939 filenames : Sequence [PathLike ] = ensure_tuple (data )
940+ self .filenames = filenames
926941 kwargs_ = self .kwargs .copy ()
927942 kwargs_ .update (kwargs )
928943 for name in filenames :
@@ -946,7 +961,7 @@ def get_data(self, img) -> tuple[np.ndarray, dict]:
946961 img_array : list [np .ndarray ] = []
947962 compatible_meta : dict = {}
948963
949- for i in ensure_tuple (img ):
964+ for i , filename in zip ( ensure_tuple (img ), self . filenames ):
950965 header = self ._get_meta_dict (i )
951966 header [MetaKeys .AFFINE ] = self ._get_affine (i )
952967 header [MetaKeys .ORIGINAL_AFFINE ] = self ._get_affine (i )
@@ -956,7 +971,7 @@ def get_data(self, img) -> tuple[np.ndarray, dict]:
956971 header [MetaKeys .AFFINE ] = self ._get_affine (i )
957972 header [MetaKeys .SPATIAL_SHAPE ] = self ._get_spatial_shape (i )
958973 header [MetaKeys .SPACE ] = SpaceKeys .RAS
959- data = self ._get_array_data (i )
974+ data = self ._get_array_data (i , filename )
960975 if self .squeeze_non_spatial_dims :
961976 for d in range (len (data .shape ), len (header [MetaKeys .SPATIAL_SHAPE ]), - 1 ):
962977 if data .shape [d - 1 ] == 1 :
@@ -969,7 +984,8 @@ def get_data(self, img) -> tuple[np.ndarray, dict]:
969984 else :
970985 header [MetaKeys .ORIGINAL_CHANNEL_DIM ] = self .channel_dim
971986 _copy_compatible_dict (header , compatible_meta )
972-
987+ if self .gpu_load :
988+ return _stack_gpu_images (img_array , compatible_meta ), compatible_meta
973989 return _stack_images (img_array , compatible_meta ), compatible_meta
974990
975991 def _get_meta_dict (self , img ) -> dict :
@@ -1022,111 +1038,40 @@ def _get_spatial_shape(self, img):
10221038 spatial_rank = max (min (ndim , 3 ), 1 )
10231039 return np .asarray (size [:spatial_rank ])
10241040
1025- def _get_array_data (self , img ):
1041+ def _get_array_data (self , img , filename ):
10261042 """
10271043 Get the raw array data of the image, converted to Numpy array.
10281044
10291045 Args:
10301046 img: a Nibabel image object loaded from an image file.
10311047
10321048 """
1049+ if self .gpu_load :
1050+ file_size = os .path .getsize (filename )
1051+ image = cp .empty (file_size , dtype = cp .uint8 )
1052+ # suggestion from Ming: more tests, diff size
1053+ # cucim + nifti
1054+ with kvikio .CuFile (filename , "r" ) as f :
1055+ f .read (image )
1056+ if filename .endswith (".gz" ):
1057+ # for compressed data, have to tansfer to CPU to decompress
1058+ # and then transfer back to GPU. It is not efficient compared to .nii file
1059+ # but it's still faster than Nibabel's default reader.
1060+ # TODO: can benchmark more, it may no need to do this since we don't have to use .gz
1061+ # since it's waste times especially in training
1062+ compressed_data = cp .asnumpy (image )
1063+ with gzip .GzipFile (fileobj = io .BytesIO (compressed_data )) as gz_file :
1064+ decompressed_data = gz_file .read ()
1065+
1066+ file_size = len (decompressed_data )
1067+ image = cp .asarray (np .frombuffer (decompressed_data , dtype = np .uint8 ))
1068+ data_shape = img .shape
1069+ data_offset = img .dataobj .offset
1070+ data_dtype = img .dataobj .dtype
1071+ return image [data_offset :].view (data_dtype ).reshape (data_shape , order = "F" )
10331072 return np .asanyarray (img .dataobj , order = "C" )
10341073
10351074
1036- @require_pkg (pkg_name = "nibabel" )
1037- @require_pkg (pkg_name = "cupy" )
1038- @require_pkg (pkg_name = "kvikio" )
1039- class NibabelGPUReader (NibabelReader ):
1040-
1041- def read (self , filename : PathLike , ** kwargs ):
1042- """
1043- Read image data from specified file or files, it can read a list of images
1044- and stack them together as multi-channel data in `get_data()`.
1045- Note that the returned object is Nibabel image object or list of Nibabel image objects.
1046-
1047- Args:
1048- data: file name.
1049-
1050- """
1051- file_size = os .path .getsize (filename )
1052- image = cp .empty (file_size , dtype = cp .uint8 )
1053- with kvikio .CuFile (filename , "r" ) as f :
1054- f .read (image )
1055-
1056- if filename .endswith (".gz" ):
1057- # for compressed data, have to tansfer to CPU to decompress
1058- # and then transfer back to GPU. It is not efficient compared to .nii file
1059- # but it's still faster than Nibabel's default reader.
1060- # TODO: can benchmark more, it may no need to do this since we don't have to use .gz
1061- # since it's waste times especially in training
1062- compressed_data = cp .asnumpy (image )
1063- with gzip .GzipFile (fileobj = io .BytesIO (compressed_data )) as gz_file :
1064- decompressed_data = gz_file .read ()
1065-
1066- file_size = len (decompressed_data )
1067- image = cp .asarray (np .frombuffer (decompressed_data , dtype = np .uint8 ))
1068- return image
1069-
1070- def get_data (self , img ):
1071- """
1072- Extract data array and metadata from loaded image and return them.
1073- This function returns two objects, first is numpy array of image data, second is dict of metadata.
1074- It constructs `affine`, `original_affine`, and `spatial_shape` and stores them in meta dict.
1075- When loading a list of files, they are stacked together at a new dimension as the first dimension,
1076- and the metadata of the first image is used to present the output metadata.
1077-
1078- Args:
1079- img: a Nibabel image object loaded from an image file.
1080-
1081- """
1082-
1083- # TODO: use a formal way for device
1084- device = torch .device ('cuda:0' if torch .cuda .is_available () else 'cpu' )
1085-
1086- header = self ._get_header (img )
1087- data_offset = header .get_data_offset ()
1088- data_shape = header .get_data_shape ()
1089- data_dtype = header .get_data_dtype ()
1090- affine = header .get_best_affine ()
1091- meta = {}
1092- meta [MetaKeys .AFFINE ] = affine
1093- meta [MetaKeys .ORIGINAL_AFFINE ] = affine
1094- # TODO: as_closest_canonical
1095- # TODO: correct_nifti_header_if_necessary
1096- meta [MetaKeys .SPATIAL_SHAPE ] = data_shape
1097- # TODO: figure out why always RAS for NibabelReader ?
1098- # meta[MetaKeys.SPACE] = SpaceKeys.RAS
1099-
1100- data = img [data_offset :].view (data_dtype ).reshape (data_shape , order = "F" )
1101- # TODO: check channel
1102- # if self.squeeze_non_spatial_dims:
1103- if self .channel_dim is None : # default to "no_channel" or -1
1104- meta [MetaKeys .ORIGINAL_CHANNEL_DIM ] = (
1105- float ("nan" ) if len (data .shape ) == len (meta [MetaKeys .SPATIAL_SHAPE ]) else - 1
1106- )
1107- else :
1108- meta [MetaKeys .ORIGINAL_CHANNEL_DIM ] = self .channel_dim
1109-
1110- return MetaTensor (data , affine = affine , meta = meta , device = device )
1111-
1112- def _get_header (self , img ):
1113- """
1114- Get the all the metadata of the image and convert to dict type.
1115-
1116- Args:
1117- img: a Nibabel image object loaded from an image file.
1118-
1119- """
1120- header_bytes = cp .asnumpy (img [:348 ])
1121- header = nib .Nifti1Header .from_fileobj (io .BytesIO (header_bytes ))
1122- # swap to little endian as PyTorch doesn't support big endian
1123- try :
1124- header = header .as_byteswapped ("<" )
1125- except ValueError :
1126- pass
1127- return header
1128-
1129-
11301075class NumpyReader (ImageReader ):
11311076 """
11321077 Load NPY or NPZ format data based on Numpy library, they can be arrays or pickled objects.
0 commit comments