@@ -142,28 +142,21 @@ def _copy_compatible_dict(from_dict: dict, to_dict: dict):
142142 )
143143
144144
145- def _stack_images (image_list : list , meta_dict : dict ):
145+ def _stack_images (image_list : list , meta_dict : dict , to_cupy : bool = False ):
146146 if len (image_list ) <= 1 :
147147 return image_list [0 ]
148148 if not is_no_channel (meta_dict .get (MetaKeys .ORIGINAL_CHANNEL_DIM , None )):
149149 channel_dim = int (meta_dict [MetaKeys .ORIGINAL_CHANNEL_DIM ])
150+ if to_cupy and has_cp :
151+ return cp .concatenate (image_list , axis = channel_dim )
150152 return np .concatenate (image_list , axis = channel_dim )
151153 # stack at a new first dim as the channel dim, if `'original_channel_dim'` is unspecified
152154 meta_dict [MetaKeys .ORIGINAL_CHANNEL_DIM ] = 0
155+ if to_cupy and has_cp :
156+ return cp .stack (image_list , axis = 0 )
153157 return np .stack (image_list , axis = 0 )
154158
155159
156- def _stack_gpu_images (image_list : list , meta_dict : dict ):
157- if len (image_list ) <= 1 :
158- return image_list [0 ]
159- if not is_no_channel (meta_dict .get (MetaKeys .ORIGINAL_CHANNEL_DIM , None )):
160- channel_dim = int (meta_dict [MetaKeys .ORIGINAL_CHANNEL_DIM ])
161- return cp .concatenate (image_list , axis = channel_dim )
162- # stack at a new first dim as the channel dim, if `'original_channel_dim'` is unspecified
163- meta_dict [MetaKeys .ORIGINAL_CHANNEL_DIM ] = 0
164- return cp .stack (image_list , axis = 0 )
165-
166-
167160@require_pkg (pkg_name = "itk" )
168161class ITKReader (ImageReader ):
169162 """
@@ -880,12 +873,16 @@ class NibabelReader(ImageReader):
880873 Load NIfTI format images based on Nibabel library.
881874
882875 Args:
883- as_closest_canonical: if True, load the image as closest to canonical axis format.
884- squeeze_non_spatial_dims: if True, non-spatial singletons will be squeezed, e.g. (256,256,1,3) -> (256,256,3)
885876 channel_dim: the channel dimension of the input image, default is None.
886877 this is used to set original_channel_dim in the metadata, EnsureChannelFirstD reads this field.
887878 if None, `original_channel_dim` will be either `no_channel` or `-1`.
888879 most Nifti files are usually "channel last", no need to specify this argument for them.
880+ as_closest_canonical: if True, load the image as closest to canonical axis format.
881+ squeeze_non_spatial_dims: if True, non-spatial singletons will be squeezed, e.g. (256,256,1,3) -> (256,256,3)
882+ to_gpu: If True, load the image into GPU memory using CuPy and Kvikio. This can accelerate data loading.
883+ Default is False. CuPy and Kvikio are required for this option.
884+ Note: For compressed NIfTI files, some operations may still be performed on CPU memory,
885+ and the acceleration may not be significant.
889886 kwargs: additional args for `nibabel.load` API. more details about available args:
890887 https://github.com/nipy/nibabel/blob/master/nibabel/loadsave.py
891888
@@ -896,15 +893,22 @@ def __init__(
896893 channel_dim : str | int | None = None ,
897894 as_closest_canonical : bool = False ,
898895 squeeze_non_spatial_dims : bool = False ,
899- gpu_load : bool = False ,
896+ to_gpu : bool = False ,
900897 ** kwargs ,
901898 ):
902899 super ().__init__ ()
903900 self .channel_dim = float ("nan" ) if channel_dim == "no_channel" else channel_dim
904901 self .as_closest_canonical = as_closest_canonical
905902 self .squeeze_non_spatial_dims = squeeze_non_spatial_dims
906- # TODO: add warning if not have required libs
907- self .gpu_load = gpu_load
903+ if to_gpu is True :
904+ if not has_cp :
905+ warnings .warn ("CuPy is not installed, fall back to use cpu load." )
906+ to_gpu = False
907+ if not has_kvikio :
908+ warnings .warn ("Kvikio is not installed, fall back to use cpu load." )
909+ to_gpu = False
910+
911+ self .to_gpu = to_gpu
908912 self .kwargs = kwargs
909913
910914 def verify_suffix (self , filename : Sequence [PathLike ] | PathLike ) -> bool :
@@ -982,8 +986,8 @@ def get_data(self, img) -> tuple[np.ndarray, dict]:
982986 else :
983987 header [MetaKeys .ORIGINAL_CHANNEL_DIM ] = self .channel_dim
984988 _copy_compatible_dict (header , compatible_meta )
985- if self .gpu_load :
986- return _stack_gpu_images (img_array , compatible_meta ), compatible_meta
989+ if self .to_gpu :
990+ return _stack_images (img_array , compatible_meta , to_cupy = True ), compatible_meta
987991 return _stack_images (img_array , compatible_meta ), compatible_meta
988992
989993 def _get_meta_dict (self , img ) -> dict :
@@ -1047,22 +1051,18 @@ def _get_array_data(self, img, filename):
10471051 if self .gpu_load :
10481052 file_size = os .path .getsize (filename )
10491053 image = cp .empty (file_size , dtype = cp .uint8 )
1050- # suggestion from Ming: more tests, diff size
1051- # cucim + nifti
10521054 with kvikio .CuFile (filename , "r" ) as f :
10531055 f .read (image )
10541056 if filename .endswith (".gz" ):
10551057 # for compressed data, have to tansfer to CPU to decompress
10561058 # and then transfer back to GPU. It is not efficient compared to .nii file
10571059 # but it's still faster than Nibabel's default reader.
1058- # TODO: can benchmark more, it may no need to do this since we don't have to use .gz
1059- # since it's waste times especially in training
10601060 compressed_data = cp .asnumpy (image )
10611061 with gzip .GzipFile (fileobj = io .BytesIO (compressed_data )) as gz_file :
10621062 decompressed_data = gz_file .read ()
10631063
10641064 file_size = len (decompressed_data )
1065- image = cp .asarray ( np . frombuffer (decompressed_data , dtype = np .uint8 ) )
1065+ image = cp .frombuffer (decompressed_data , dtype = cp .uint8 )
10661066 data_shape = img .shape
10671067 data_offset = img .dataobj .offset
10681068 data_dtype = img .dataobj .dtype
0 commit comments