@@ -418,6 +418,10 @@ class PydicomReader(ImageReader):
418418 If provided, only the matched files will be included. For example, to include the file name
419419 "image_0001.dcm", the regular expression could be `".*image_(\\d+).dcm"`. Default to `""`.
420420 Set it to `None` to use `pydicom.misc.is_dicom` to match valid files.
421+ to_gpu: If True, load the image into GPU memory using CuPy and Kvikio. This can accelerate data loading.
422+ Default is False. CuPy and Kvikio are required for this option.
423+ In practical use, it's recommended to add a warm up call before the actual loading.
424+ A related tutorial will be prepared in the future, and the document will be updated accordingly.
421425 kwargs: additional args for `pydicom.dcmread` API. more details about available args:
422426 https://pydicom.github.io/pydicom/stable/reference/generated/pydicom.filereader.dcmread.html
423427 If the `get_data` function will be called
@@ -434,6 +438,7 @@ def __init__(
434438 prune_metadata : bool = True ,
435439 label_dict : dict | None = None ,
436440 fname_regex : str = "" ,
441+ to_gpu : bool = False ,
437442 ** kwargs ,
438443 ):
439444 super ().__init__ ()
@@ -444,6 +449,33 @@ def __init__(
444449 self .prune_metadata = prune_metadata
445450 self .label_dict = label_dict
446451 self .fname_regex = fname_regex
452+ if to_gpu and (not has_cp or not has_kvikio ):
453+ warnings .warn (
454+ "PydicomReader: CuPy and/or Kvikio not installed for GPU loading, falling back to CPU loading."
455+ )
456+ to_gpu = False
457+
458+ if to_gpu :
459+ self .warmup_kvikio ()
460+
461+ self .to_gpu = to_gpu
462+
463+ def warmup_kvikio (self ):
464+ """
465+ Warm up the Kvikio library to initialize the internal buffers, cuFile, GDS, etc.
466+ This can accelerate the data loading process when `to_gpu` is set to True.
467+ """
468+ if has_cp and has_kvikio :
469+ a = cp .arange (100 )
470+ with tempfile .NamedTemporaryFile () as tmp_file :
471+ tmp_file_name = tmp_file .name
472+ f = kvikio .CuFile (tmp_file_name , "w" )
473+ f .write (a )
474+ f .close ()
475+
476+ b = cp .empty_like (a )
477+ f = kvikio .CuFile (tmp_file_name , "r" )
478+ f .read (b )
447479
448480 def verify_suffix (self , filename : Sequence [PathLike ] | PathLike ) -> bool :
449481 """
@@ -475,19 +507,23 @@ def read(self, data: Sequence[PathLike] | PathLike, **kwargs):
475507 img_ = []
476508
477509 filenames : Sequence [PathLike ] = ensure_tuple (data )
510+ self .filenames = filenames
478511 kwargs_ = self .kwargs .copy ()
512+ if self .to_gpu :
513+ kwargs ["defer_size" ] = "100 KB"
479514 kwargs_ .update (kwargs )
480515
481516 self .has_series = False
482517
483- for name in filenames :
518+ for i , name in enumerate ( filenames ) :
484519 name = f"{ name } "
485520 if Path (name ).is_dir ():
486521 # read DICOM series
487522 if self .fname_regex is not None :
488523 series_slcs = [slc for slc in glob .glob (os .path .join (name , "*" )) if re .match (self .fname_regex , slc )]
489524 else :
490525 series_slcs = [slc for slc in glob .glob (os .path .join (name , "*" )) if pydicom .misc .is_dicom (slc )]
526+ self .filenames [i ] = series_slcs
491527 slices = []
492528 for slc in series_slcs :
493529 try :
@@ -502,7 +538,7 @@ def read(self, data: Sequence[PathLike] | PathLike, **kwargs):
502538 img_ .append (ds )
503539 return img_ if len (filenames ) > 1 else img_ [0 ]
504540
505- def _combine_dicom_series (self , data : Iterable ):
541+ def _combine_dicom_series (self , data : Iterable , filenames : Sequence [ PathLike ] ):
506542 """
507543 Combine dicom series (a list of pydicom dataset objects). Their data arrays will be stacked together at a new
508544 dimension as the last dimension.
@@ -522,25 +558,25 @@ def _combine_dicom_series(self, data: Iterable):
522558 """
523559 slices : list = []
524560 # for a dicom series
525- for slc_ds in data :
561+ for slc_ds , filename in zip ( data , filenames ) :
526562 if hasattr (slc_ds , "InstanceNumber" ):
527- slices .append (slc_ds )
563+ slices .append (( slc_ds , filename ) )
528564 else :
529- warnings .warn (f"slice: { slc_ds . filename } does not have InstanceNumber tag, skip it." )
530- slices = sorted (slices , key = lambda s : s .InstanceNumber )
565+ warnings .warn (f"slice: { filename } does not have InstanceNumber tag, skip it." )
566+ slices = sorted (slices , key = lambda s : s [ 0 ] .InstanceNumber )
531567
532568 if len (slices ) == 0 :
533569 raise ValueError ("the input does not have valid slices." )
534570
535- first_slice = slices [0 ]
571+ first_slice , first_filename = slices [0 ]
536572 average_distance = 0.0
537- first_array = self ._get_array_data (first_slice )
573+ first_array = self ._get_array_data (first_slice , first_filename )
538574 shape = first_array .shape
539575 spacing = getattr (first_slice , "PixelSpacing" , [1.0 , 1.0 , 1.0 ])
540576 prev_pos = getattr (first_slice , "ImagePositionPatient" , (0.0 , 0.0 , 0.0 ))[2 ]
541577 stack_array = [first_array ]
542578 for idx in range (1 , len (slices )):
543- slc_array = self ._get_array_data (slices [idx ])
579+ slc_array = self ._get_array_data (slices [idx ][ 0 ], slices [ idx ][ 1 ] )
544580 slc_shape = slc_array .shape
545581 slc_spacing = getattr (slices [idx ], "PixelSpacing" , (1.0 , 1.0 , 1.0 ))
546582 slc_pos = getattr (slices [idx ], "ImagePositionPatient" , (0.0 , 0.0 , float (idx )))[2 ]
@@ -555,7 +591,10 @@ def _combine_dicom_series(self, data: Iterable):
555591 if len (slices ) > 1 :
556592 average_distance /= len (slices ) - 1
557593 spacing .append (average_distance )
558- stack_array = np .stack (stack_array , axis = - 1 )
594+ if self .to_gpu :
595+ stack_array = cp .stack (stack_array , axis = - 1 )
596+ else :
597+ stack_array = np .stack (stack_array , axis = - 1 )
559598 stack_metadata = self ._get_meta_dict (first_slice )
560599 stack_metadata ["spacing" ] = np .asarray (spacing )
561600 if hasattr (slices [- 1 ], "ImagePositionPatient" ):
@@ -597,29 +636,35 @@ def get_data(self, data) -> tuple[np.ndarray, dict]:
597636 if self .has_series is True :
598637 # a list, all objects within a list belong to one dicom series
599638 if not isinstance (data [0 ], list ):
600- dicom_data .append (self ._combine_dicom_series (data ))
639+ dicom_data .append (self ._combine_dicom_series (data , self . filenames ))
601640 # a list of list, each inner list represents a dicom series
602641 else :
603- for series in data :
604- dicom_data .append (self ._combine_dicom_series (series ))
642+ for i , series in enumerate ( data ) :
643+ dicom_data .append (self ._combine_dicom_series (series , self . filenames [ i ] ))
605644 else :
606645 # a single pydicom dataset object
607646 if not isinstance (data , list ):
608647 data = [data ]
609- for d in data :
648+ for i , d in enumerate ( data ) :
610649 if hasattr (d , "SegmentSequence" ):
611- data_array , metadata = self ._get_seg_data (d )
650+ data_array , metadata = self ._get_seg_data (d , self . filenames [ i ] )
612651 else :
613- data_array = self ._get_array_data (d )
652+ data_array = self ._get_array_data (d , self . filenames [ i ] )
614653 metadata = self ._get_meta_dict (d )
615654 metadata [MetaKeys .SPATIAL_SHAPE ] = data_array .shape
616655 dicom_data .append ((data_array , metadata ))
617656
657+
658+ # TODO: the actual type is list[np.ndarray | cp.ndarray]
659+ # should figure out how to define correct types without having cupy not found error
660+ # https://github.com/Project-MONAI/MONAI/pull/8188#discussion_r1886645918
618661 img_array : list [np .ndarray ] = []
619662 compatible_meta : dict = {}
620663
621664 for data_array , metadata in ensure_tuple (dicom_data ):
622- img_array .append (np .ascontiguousarray (np .swapaxes (data_array , 0 , 1 ) if self .swap_ij else data_array ))
665+ if self .swap_ij :
666+ data_array = cp .swapaxes (data_array , 0 , 1 ) if self .to_gpu else np .swapaxes (data_array , 0 , 1 )
667+ img_array .append (cp .ascontiguousarray (data_array ) if self .to_gpu else np .ascontiguousarray (data_array ))
623668 affine = self ._get_affine (metadata , self .affine_lps_to_ras )
624669 metadata [MetaKeys .SPACE ] = SpaceKeys .RAS if self .affine_lps_to_ras else SpaceKeys .LPS
625670 if self .swap_ij :
@@ -641,7 +686,7 @@ def get_data(self, data) -> tuple[np.ndarray, dict]:
641686
642687 _copy_compatible_dict (metadata , compatible_meta )
643688
644- return _stack_images (img_array , compatible_meta ), compatible_meta
689+ return _stack_images (img_array , compatible_meta , to_cupy = self . to_gpu ), compatible_meta
645690
646691 def _get_meta_dict (self , img ) -> dict :
647692 """
@@ -713,7 +758,7 @@ def _get_affine(self, metadata: dict, lps_to_ras: bool = True):
713758 affine = orientation_ras_lps (affine )
714759 return affine
715760
716- def _get_frame_data (self , img ) -> Iterator :
761+ def _get_frame_data (self , img , filename , array_data ) -> Iterator :
717762 """
718763 yield frames and description from the segmentation image.
719764 This function is adapted from Highdicom:
@@ -752,47 +797,55 @@ def _get_frame_data(self, img) -> Iterator:
752797
753798 if not hasattr (img , "PerFrameFunctionalGroupsSequence" ):
754799 raise NotImplementedError (
755- f"To read dicom seg: { img . filename } , 'PerFrameFunctionalGroupsSequence' is required."
800+ f"To read dicom seg: { filename } , 'PerFrameFunctionalGroupsSequence' is required."
756801 )
757802
758803 frame_seg_nums = []
759804 for f in img .PerFrameFunctionalGroupsSequence :
760805 if not hasattr (f , "SegmentIdentificationSequence" ):
761806 raise NotImplementedError (
762- f"To read dicom seg: { img . filename } , 'SegmentIdentificationSequence' is required for each frame."
807+ f"To read dicom seg: { filename } , 'SegmentIdentificationSequence' is required for each frame."
763808 )
764809 frame_seg_nums .append (int (f .SegmentIdentificationSequence [0 ].ReferencedSegmentNumber ))
765810
766- frame_seg_nums_arr = np .array (frame_seg_nums )
811+ frame_seg_nums_arr = cp . array ( frame_seg_nums ) if self . to_gpu else np .array (frame_seg_nums )
767812
768813 seg_descriptions = {int (f .SegmentNumber ): f for f in img .SegmentSequence }
769814
770- for i in np .unique (frame_seg_nums_arr ):
771- indices = np .where (frame_seg_nums_arr == i )[0 ]
772- yield (img . pixel_array [indices , ...], seg_descriptions [i ])
815+ for i in np .unique (frame_seg_nums_arr ) if not self . to_gpu else cp . unique ( frame_seg_nums_arr ) :
816+ indices = np .where (frame_seg_nums_arr == i )[0 ] if not self . to_gpu else cp . where ( frame_seg_nums_arr == i )[ 0 ]
817+ yield (array_data [indices , ...], seg_descriptions [i ])
773818
774- def _get_seg_data (self , img ):
819+ def _get_seg_data (self , img , filename ):
775820 """
776821 Get the array data and metadata of the segmentation image.
777822
778823 Args:
779824 img: a Pydicom dataset object that has attribute "SegmentSequence".
825+ filename: the file path of the image.
780826
781827 """
782828
783829 metadata = self ._get_meta_dict (img )
784830 n_classes = len (img .SegmentSequence )
785- spatial_shape = list (img .pixel_array .shape )
831+ array_data = self ._get_array_data (img , filename )
832+ spatial_shape = list (array_data .shape )
786833 spatial_shape [0 ] = spatial_shape [0 ] // n_classes
787834
788835 if self .label_dict is not None :
789836 metadata ["labels" ] = self .label_dict
790- all_segs = np .zeros ([* spatial_shape , len (self .label_dict )])
837+ if self .to_gpu :
838+ all_segs = cp .zeros ([* spatial_shape , len (self .label_dict )], dtype = array_data .dtype )
839+ else :
840+ all_segs = np .zeros ([* spatial_shape , len (self .label_dict )], dtype = array_data .dtype )
791841 else :
792842 metadata ["labels" ] = {}
793- all_segs = np .zeros ([* spatial_shape , n_classes ])
843+ if self .to_gpu :
844+ all_segs = cp .zeros ([* spatial_shape , n_classes ], dtype = array_data .dtype )
845+ else :
846+ all_segs = np .zeros ([* spatial_shape , n_classes ], dtype = array_data .dtype )
794847
795- for i , (frames , description ) in enumerate (self ._get_frame_data (img )):
848+ for i , (frames , description ) in enumerate (self ._get_frame_data (img , filename , array_data )):
796849 segment_label = getattr (description , "SegmentLabel" , f"label_{ i } " )
797850 class_name = getattr (description , "SegmentDescription" , segment_label )
798851 if class_name not in metadata ["labels" ].keys ():
@@ -840,19 +893,51 @@ def _get_seg_data(self, img):
840893
841894 return all_segs , metadata
842895
843- def _get_array_data (self , img ):
896+ def _get_array_data (self , img , filename ):
844897 """
845898 Get the array data of the image. If `RescaleSlope` and `RescaleIntercept` are available, the raw array data
846- will be rescaled. The output data has the dtype np. float32 if the rescaling is applied.
899+ will be rescaled. The output data has the dtype float32 if the rescaling is applied.
847900
848901 Args:
849902 img: a Pydicom dataset object.
903+ filename: the file path of the image.
850904
851905 """
852906 # process Dicom series
853- if not hasattr (img , "pixel_array" ):
854- raise ValueError (f"dicom data: { img .filename } does not have pixel_array." )
855- data = img .pixel_array
907+
908+ if self .to_gpu :
909+ rows = img .Rows
910+ columns = img .Columns
911+ bits_allocated = img .BitsAllocated
912+ samples_per_pixel = img .SamplesPerPixel
913+ number_of_frames = getattr (img , 'NumberOfFrames' , 1 )
914+ pixel_representation = img .PixelRepresentation
915+
916+ if bits_allocated == 8 :
917+ dtype = cp .int8 if pixel_representation == 1 else cp .uint8
918+ elif bits_allocated == 16 :
919+ dtype = cp .int16 if pixel_representation == 1 else cp .uint16
920+ elif bits_allocated == 32 :
921+ dtype = cp .int32 if pixel_representation == 1 else cp .uint32
922+ else :
923+ raise ValueError ("Unsupported BitsAllocated value" )
924+
925+ bytes_per_pixel = bits_allocated // 8
926+ total_pixels = rows * columns * samples_per_pixel * number_of_frames
927+ expected_pixel_data_length = total_pixels * bytes_per_pixel
928+
929+ offset = img .get_item (0x7FE00010 , keep_deferred = True ).value_tell
930+
931+ with kvikio .CuFile (filename , "r" ) as f :
932+ buffer = cp .empty (expected_pixel_data_length , dtype = cp .int8 )
933+ f .read (buffer , expected_pixel_data_length , offset )
934+
935+ data = buffer .view (dtype ).reshape ((number_of_frames , rows , columns ))
936+
937+ else :
938+ if not hasattr (img , "pixel_array" ):
939+ raise ValueError (f"dicom data: { filename } does not have pixel_array." )
940+ data = img .pixel_array
856941
857942 slope , offset = 1.0 , 0.0
858943 rescale_flag = False
@@ -862,8 +947,12 @@ def _get_array_data(self, img):
862947 if hasattr (img , "RescaleIntercept" ):
863948 offset = img .RescaleIntercept
864949 rescale_flag = True
950+
865951 if rescale_flag :
866- data = data .astype (np .float32 ) * slope + offset
952+ if self .to_gpu :
953+ data = data .astype (cp .float32 ) * slope + offset
954+ else :
955+ data = data .astype (np .float32 ) * slope + offset
867956
868957 return data
869958
@@ -884,8 +973,6 @@ class NibabelReader(ImageReader):
884973 Default is False. CuPy and Kvikio are required for this option.
885974 Note: For compressed NIfTI files, some operations may still be performed on CPU memory,
886975 and the acceleration may not be significant. In some cases, it may be slower than loading on CPU.
887- In practical use, it's recommended to add a warm up call before the actual loading.
888- A related tutorial will be prepared in the future, and the document will be updated accordingly.
889976 kwargs: additional args for `nibabel.load` API. more details about available args:
890977 https://github.com/nipy/nibabel/blob/master/nibabel/loadsave.py
891978
0 commit comments