1212from __future__ import annotations
1313
1414import glob
15+ import gzip
16+ import io
1517import os
1618import re
19+ import tempfile
1720import warnings
1821from abc import ABC , abstractmethod
1922from collections .abc import Callable , Iterable , Iterator , Sequence
5154 pydicom , has_pydicom = optional_import ("pydicom" )
5255 nrrd , has_nrrd = optional_import ("nrrd" , allow_namespace_pkg = True )
5356
57+ cp , has_cp = optional_import ("cupy" )
58+ kvikio , has_kvikio = optional_import ("kvikio" )
59+
5460__all__ = ["ImageReader" , "ITKReader" , "NibabelReader" , "NumpyReader" , "PILReader" , "PydicomReader" , "NrrdReader" ]
5561
5662
@@ -137,14 +143,18 @@ def _copy_compatible_dict(from_dict: dict, to_dict: dict):
137143 )
138144
139145
140- def _stack_images (image_list : list , meta_dict : dict ):
146+ def _stack_images (image_list : list , meta_dict : dict , to_cupy : bool = False ):
141147 if len (image_list ) <= 1 :
142148 return image_list [0 ]
143149 if not is_no_channel (meta_dict .get (MetaKeys .ORIGINAL_CHANNEL_DIM , None )):
144150 channel_dim = int (meta_dict [MetaKeys .ORIGINAL_CHANNEL_DIM ])
151+ if to_cupy and has_cp :
152+ return cp .concatenate (image_list , axis = channel_dim )
145153 return np .concatenate (image_list , axis = channel_dim )
146154 # stack at a new first dim as the channel dim, if `'original_channel_dim'` is unspecified
147155 meta_dict [MetaKeys .ORIGINAL_CHANNEL_DIM ] = 0
156+ if to_cupy and has_cp :
157+ return cp .stack (image_list , axis = 0 )
148158 return np .stack (image_list , axis = 0 )
149159
150160
@@ -864,12 +874,18 @@ class NibabelReader(ImageReader):
864874 Load NIfTI format images based on Nibabel library.
865875
866876 Args:
867- as_closest_canonical: if True, load the image as closest to canonical axis format.
868- squeeze_non_spatial_dims: if True, non-spatial singletons will be squeezed, e.g. (256,256,1,3) -> (256,256,3)
869877 channel_dim: the channel dimension of the input image, default is None.
870878 this is used to set original_channel_dim in the metadata, EnsureChannelFirstD reads this field.
871879 if None, `original_channel_dim` will be either `no_channel` or `-1`.
872880 most Nifti files are usually "channel last", no need to specify this argument for them.
881+ as_closest_canonical: if True, load the image as closest to canonical axis format.
882+ squeeze_non_spatial_dims: if True, non-spatial singletons will be squeezed, e.g. (256,256,1,3) -> (256,256,3)
883+ to_gpu: If True, load the image into GPU memory using CuPy and Kvikio. This can accelerate data loading.
884+ Default is False. CuPy and Kvikio are required for this option.
885+ Note: For compressed NIfTI files, some operations may still be performed on CPU memory,
886+ and the acceleration may not be significant. In some cases, it may be slower than loading on CPU.
887+ In practical use, it's recommended to add a warm up call before the actual loading.
888+ A related tutorial will be prepared in the future, and the document will be updated accordingly.
873889 kwargs: additional args for `nibabel.load` API. more details about available args:
874890 https://github.com/nipy/nibabel/blob/master/nibabel/loadsave.py
875891
@@ -880,14 +896,42 @@ def __init__(
880896 channel_dim : str | int | None = None ,
881897 as_closest_canonical : bool = False ,
882898 squeeze_non_spatial_dims : bool = False ,
899+ to_gpu : bool = False ,
883900 ** kwargs ,
884901 ):
885902 super ().__init__ ()
886903 self .channel_dim = float ("nan" ) if channel_dim == "no_channel" else channel_dim
887904 self .as_closest_canonical = as_closest_canonical
888905 self .squeeze_non_spatial_dims = squeeze_non_spatial_dims
906+ if to_gpu and (not has_cp or not has_kvikio ):
907+ warnings .warn (
908+ "NibabelReader: CuPy and/or Kvikio not installed for GPU loading, falling back to CPU loading."
909+ )
910+ to_gpu = False
911+
912+ if to_gpu :
913+ self .warmup_kvikio ()
914+
915+ self .to_gpu = to_gpu
889916 self .kwargs = kwargs
890917
918+ def warmup_kvikio (self ):
919+ """
920+ Warm up the Kvikio library to initialize the internal buffers, cuFile, GDS, etc.
921+ This can accelerate the data loading process when `to_gpu` is set to True.
922+ """
923+ if has_cp and has_kvikio :
924+ a = cp .arange (100 )
925+ with tempfile .NamedTemporaryFile () as tmp_file :
926+ tmp_file_name = tmp_file .name
927+ f = kvikio .CuFile (tmp_file_name , "w" )
928+ f .write (a )
929+ f .close ()
930+
931+ b = cp .empty_like (a )
932+ f = kvikio .CuFile (tmp_file_name , "r" )
933+ f .read (b )
934+
891935 def verify_suffix (self , filename : Sequence [PathLike ] | PathLike ) -> bool :
892936 """
893937 Verify whether the specified file or files format is supported by Nibabel reader.
@@ -916,6 +960,7 @@ def read(self, data: Sequence[PathLike] | PathLike, **kwargs):
916960 img_ : list [Nifti1Image ] = []
917961
918962 filenames : Sequence [PathLike ] = ensure_tuple (data )
963+ self .filenames = filenames
919964 kwargs_ = self .kwargs .copy ()
920965 kwargs_ .update (kwargs )
921966 for name in filenames :
@@ -936,10 +981,13 @@ def get_data(self, img) -> tuple[np.ndarray, dict]:
936981 img: a Nibabel image object loaded from an image file or a list of Nibabel image objects.
937982
938983 """
984+ # TODO: the actual type is list[np.ndarray | cp.ndarray]
985+ # should figure out how to define correct types without having cupy not found error
986+ # https://github.com/Project-MONAI/MONAI/pull/8188#discussion_r1886645918
939987 img_array : list [np .ndarray ] = []
940988 compatible_meta : dict = {}
941989
942- for i in ensure_tuple (img ):
990+ for i , filename in zip ( ensure_tuple (img ), self . filenames ):
943991 header = self ._get_meta_dict (i )
944992 header [MetaKeys .AFFINE ] = self ._get_affine (i )
945993 header [MetaKeys .ORIGINAL_AFFINE ] = self ._get_affine (i )
@@ -949,7 +997,7 @@ def get_data(self, img) -> tuple[np.ndarray, dict]:
949997 header [MetaKeys .AFFINE ] = self ._get_affine (i )
950998 header [MetaKeys .SPATIAL_SHAPE ] = self ._get_spatial_shape (i )
951999 header [MetaKeys .SPACE ] = SpaceKeys .RAS
952- data = self ._get_array_data (i )
1000+ data = self ._get_array_data (i , filename )
9531001 if self .squeeze_non_spatial_dims :
9541002 for d in range (len (data .shape ), len (header [MetaKeys .SPATIAL_SHAPE ]), - 1 ):
9551003 if data .shape [d - 1 ] == 1 :
@@ -963,7 +1011,7 @@ def get_data(self, img) -> tuple[np.ndarray, dict]:
9631011 header [MetaKeys .ORIGINAL_CHANNEL_DIM ] = self .channel_dim
9641012 _copy_compatible_dict (header , compatible_meta )
9651013
966- return _stack_images (img_array , compatible_meta ), compatible_meta
1014+ return _stack_images (img_array , compatible_meta , to_cupy = self . to_gpu ), compatible_meta
9671015
9681016 def _get_meta_dict (self , img ) -> dict :
9691017 """
@@ -1015,14 +1063,34 @@ def _get_spatial_shape(self, img):
10151063 spatial_rank = max (min (ndim , 3 ), 1 )
10161064 return np .asarray (size [:spatial_rank ])
10171065
1018- def _get_array_data (self , img ):
1066+ def _get_array_data (self , img , filename ):
10191067 """
10201068 Get the raw array data of the image, converted to Numpy array.
10211069
10221070 Args:
10231071 img: a Nibabel image object loaded from an image file.
1024-
1025- """
1072+ filename: file name of the image.
1073+
1074+ """
1075+ if self .to_gpu :
1076+ file_size = os .path .getsize (filename )
1077+ image = cp .empty (file_size , dtype = cp .uint8 )
1078+ with kvikio .CuFile (filename , "r" ) as f :
1079+ f .read (image )
1080+ if filename .endswith (".nii.gz" ):
1081+ # for compressed data, have to tansfer to CPU to decompress
1082+ # and then transfer back to GPU. It is not efficient compared to .nii file
1083+ # and may be slower than CPU loading in some cases.
1084+ warnings .warn ("Loading compressed NIfTI file into GPU may not be efficient." )
1085+ compressed_data = cp .asnumpy (image )
1086+ with gzip .GzipFile (fileobj = io .BytesIO (compressed_data )) as gz_file :
1087+ decompressed_data = gz_file .read ()
1088+
1089+ image = cp .frombuffer (decompressed_data , dtype = cp .uint8 )
1090+ data_shape = img .shape
1091+ data_offset = img .dataobj .offset
1092+ data_dtype = img .dataobj .dtype
1093+ return image [data_offset :].view (data_dtype ).reshape (data_shape , order = "F" )
10261094 return np .asanyarray (img .dataobj , order = "C" )
10271095
10281096
0 commit comments