Skip to content

Commit abca654

Browse files
update pydicom reader
Signed-off-by: Yiheng Wang <[email protected]>
1 parent 996e876 commit abca654

File tree

1 file changed

+125
-38
lines changed

1 file changed

+125
-38
lines changed

monai/data/image_reader.py

Lines changed: 125 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -418,6 +418,10 @@ class PydicomReader(ImageReader):
418418
If provided, only the matched files will be included. For example, to include the file name
419419
"image_0001.dcm", the regular expression could be `".*image_(\\d+).dcm"`. Default to `""`.
420420
Set it to `None` to use `pydicom.misc.is_dicom` to match valid files.
421+
to_gpu: If True, load the image into GPU memory using CuPy and Kvikio. This can accelerate data loading.
422+
Default is False. CuPy and Kvikio are required for this option.
423+
In practical use, it's recommended to add a warm up call before the actual loading.
424+
A related tutorial will be prepared in the future, and the document will be updated accordingly.
421425
kwargs: additional args for `pydicom.dcmread` API. more details about available args:
422426
https://pydicom.github.io/pydicom/stable/reference/generated/pydicom.filereader.dcmread.html
423427
If the `get_data` function will be called
@@ -434,6 +438,7 @@ def __init__(
434438
prune_metadata: bool = True,
435439
label_dict: dict | None = None,
436440
fname_regex: str = "",
441+
to_gpu: bool = False,
437442
**kwargs,
438443
):
439444
super().__init__()
@@ -444,6 +449,33 @@ def __init__(
444449
self.prune_metadata = prune_metadata
445450
self.label_dict = label_dict
446451
self.fname_regex = fname_regex
452+
if to_gpu and (not has_cp or not has_kvikio):
453+
warnings.warn(
454+
"PydicomReader: CuPy and/or Kvikio not installed for GPU loading, falling back to CPU loading."
455+
)
456+
to_gpu = False
457+
458+
if to_gpu:
459+
self.warmup_kvikio()
460+
461+
self.to_gpu = to_gpu
462+
463+
def warmup_kvikio(self):
    """
    Warm up the Kvikio library to initialize the internal buffers, cuFile, GDS, etc.
    This can accelerate the data loading process when `to_gpu` is set to True.

    A small CuPy array is written to a temporary file and read back through
    `kvikio.CuFile`. Nothing is returned; the round trip exists only for its
    initialization side effects. The method does nothing when CuPy or Kvikio
    is not available.
    """
    if not (has_cp and has_kvikio):
        return

    a = cp.arange(100)
    with tempfile.NamedTemporaryFile() as tmp_file:
        tmp_file_name = tmp_file.name
        # use context managers so that both file handles are always closed,
        # including the read handle and the case where write/read raises.
        with kvikio.CuFile(tmp_file_name, "w") as f:
            f.write(a)

        b = cp.empty_like(a)
        with kvikio.CuFile(tmp_file_name, "r") as f:
            f.read(b)
447479

448480
def verify_suffix(self, filename: Sequence[PathLike] | PathLike) -> bool:
449481
"""
@@ -475,19 +507,23 @@ def read(self, data: Sequence[PathLike] | PathLike, **kwargs):
475507
img_ = []
476508

477509
filenames: Sequence[PathLike] = ensure_tuple(data)
510+
self.filenames = filenames
478511
kwargs_ = self.kwargs.copy()
512+
if self.to_gpu:
513+
kwargs["defer_size"] = "100 KB"
479514
kwargs_.update(kwargs)
480515

481516
self.has_series = False
482517

483-
for name in filenames:
518+
for i, name in enumerate(filenames):
484519
name = f"{name}"
485520
if Path(name).is_dir():
486521
# read DICOM series
487522
if self.fname_regex is not None:
488523
series_slcs = [slc for slc in glob.glob(os.path.join(name, "*")) if re.match(self.fname_regex, slc)]
489524
else:
490525
series_slcs = [slc for slc in glob.glob(os.path.join(name, "*")) if pydicom.misc.is_dicom(slc)]
526+
self.filenames[i] = series_slcs
491527
slices = []
492528
for slc in series_slcs:
493529
try:
@@ -502,7 +538,7 @@ def read(self, data: Sequence[PathLike] | PathLike, **kwargs):
502538
img_.append(ds)
503539
return img_ if len(filenames) > 1 else img_[0]
504540

505-
def _combine_dicom_series(self, data: Iterable):
541+
def _combine_dicom_series(self, data: Iterable, filenames: Sequence[PathLike]):
506542
"""
507543
Combine dicom series (a list of pydicom dataset objects). Their data arrays will be stacked together at a new
508544
dimension as the last dimension.
@@ -522,25 +558,25 @@ def _combine_dicom_series(self, data: Iterable):
522558
"""
523559
slices: list = []
524560
# for a dicom series
525-
for slc_ds in data:
561+
for slc_ds, filename in zip(data, filenames):
526562
if hasattr(slc_ds, "InstanceNumber"):
527-
slices.append(slc_ds)
563+
slices.append((slc_ds, filename))
528564
else:
529-
warnings.warn(f"slice: {slc_ds.filename} does not have InstanceNumber tag, skip it.")
530-
slices = sorted(slices, key=lambda s: s.InstanceNumber)
565+
warnings.warn(f"slice: {filename} does not have InstanceNumber tag, skip it.")
566+
slices = sorted(slices, key=lambda s: s[0].InstanceNumber)
531567

532568
if len(slices) == 0:
533569
raise ValueError("the input does not have valid slices.")
534570

535-
first_slice = slices[0]
571+
first_slice, first_filename = slices[0]
536572
average_distance = 0.0
537-
first_array = self._get_array_data(first_slice)
573+
first_array = self._get_array_data(first_slice, first_filename)
538574
shape = first_array.shape
539575
spacing = getattr(first_slice, "PixelSpacing", [1.0, 1.0, 1.0])
540576
prev_pos = getattr(first_slice, "ImagePositionPatient", (0.0, 0.0, 0.0))[2]
541577
stack_array = [first_array]
542578
for idx in range(1, len(slices)):
543-
slc_array = self._get_array_data(slices[idx])
579+
slc_array = self._get_array_data(slices[idx][0], slices[idx][1])
544580
slc_shape = slc_array.shape
545581
slc_spacing = getattr(slices[idx], "PixelSpacing", (1.0, 1.0, 1.0))
546582
slc_pos = getattr(slices[idx], "ImagePositionPatient", (0.0, 0.0, float(idx)))[2]
@@ -555,7 +591,10 @@ def _combine_dicom_series(self, data: Iterable):
555591
if len(slices) > 1:
556592
average_distance /= len(slices) - 1
557593
spacing.append(average_distance)
558-
stack_array = np.stack(stack_array, axis=-1)
594+
if self.to_gpu:
595+
stack_array = cp.stack(stack_array, axis=-1)
596+
else:
597+
stack_array = np.stack(stack_array, axis=-1)
559598
stack_metadata = self._get_meta_dict(first_slice)
560599
stack_metadata["spacing"] = np.asarray(spacing)
561600
if hasattr(slices[-1], "ImagePositionPatient"):
@@ -597,29 +636,35 @@ def get_data(self, data) -> tuple[np.ndarray, dict]:
597636
if self.has_series is True:
598637
# a list, all objects within a list belong to one dicom series
599638
if not isinstance(data[0], list):
600-
dicom_data.append(self._combine_dicom_series(data))
639+
dicom_data.append(self._combine_dicom_series(data, self.filenames))
601640
# a list of list, each inner list represents a dicom series
602641
else:
603-
for series in data:
604-
dicom_data.append(self._combine_dicom_series(series))
642+
for i, series in enumerate(data):
643+
dicom_data.append(self._combine_dicom_series(series, self.filenames[i]))
605644
else:
606645
# a single pydicom dataset object
607646
if not isinstance(data, list):
608647
data = [data]
609-
for d in data:
648+
for i, d in enumerate(data):
610649
if hasattr(d, "SegmentSequence"):
611-
data_array, metadata = self._get_seg_data(d)
650+
data_array, metadata = self._get_seg_data(d, self.filenames[i])
612651
else:
613-
data_array = self._get_array_data(d)
652+
data_array = self._get_array_data(d, self.filenames[i])
614653
metadata = self._get_meta_dict(d)
615654
metadata[MetaKeys.SPATIAL_SHAPE] = data_array.shape
616655
dicom_data.append((data_array, metadata))
617656

657+
658+
# TODO: the actual type is list[np.ndarray | cp.ndarray]
659+
# we should figure out how to annotate this correctly without triggering a "cupy not found" error
660+
# https://github.com/Project-MONAI/MONAI/pull/8188#discussion_r1886645918
618661
img_array: list[np.ndarray] = []
619662
compatible_meta: dict = {}
620663

621664
for data_array, metadata in ensure_tuple(dicom_data):
622-
img_array.append(np.ascontiguousarray(np.swapaxes(data_array, 0, 1) if self.swap_ij else data_array))
665+
if self.swap_ij:
666+
data_array = cp.swapaxes(data_array, 0, 1) if self.to_gpu else np.swapaxes(data_array, 0, 1)
667+
img_array.append(cp.ascontiguousarray(data_array) if self.to_gpu else np.ascontiguousarray(data_array))
623668
affine = self._get_affine(metadata, self.affine_lps_to_ras)
624669
metadata[MetaKeys.SPACE] = SpaceKeys.RAS if self.affine_lps_to_ras else SpaceKeys.LPS
625670
if self.swap_ij:
@@ -641,7 +686,7 @@ def get_data(self, data) -> tuple[np.ndarray, dict]:
641686

642687
_copy_compatible_dict(metadata, compatible_meta)
643688

644-
return _stack_images(img_array, compatible_meta), compatible_meta
689+
return _stack_images(img_array, compatible_meta, to_cupy=self.to_gpu), compatible_meta
645690

646691
def _get_meta_dict(self, img) -> dict:
647692
"""
@@ -713,7 +758,7 @@ def _get_affine(self, metadata: dict, lps_to_ras: bool = True):
713758
affine = orientation_ras_lps(affine)
714759
return affine
715760

716-
def _get_frame_data(self, img) -> Iterator:
761+
def _get_frame_data(self, img, filename, array_data) -> Iterator:
717762
"""
718763
yield frames and description from the segmentation image.
719764
This function is adapted from Highdicom:
@@ -752,47 +797,55 @@ def _get_frame_data(self, img) -> Iterator:
752797

753798
if not hasattr(img, "PerFrameFunctionalGroupsSequence"):
754799
raise NotImplementedError(
755-
f"To read dicom seg: {img.filename}, 'PerFrameFunctionalGroupsSequence' is required."
800+
f"To read dicom seg: {filename}, 'PerFrameFunctionalGroupsSequence' is required."
756801
)
757802

758803
frame_seg_nums = []
759804
for f in img.PerFrameFunctionalGroupsSequence:
760805
if not hasattr(f, "SegmentIdentificationSequence"):
761806
raise NotImplementedError(
762-
f"To read dicom seg: {img.filename}, 'SegmentIdentificationSequence' is required for each frame."
807+
f"To read dicom seg: {filename}, 'SegmentIdentificationSequence' is required for each frame."
763808
)
764809
frame_seg_nums.append(int(f.SegmentIdentificationSequence[0].ReferencedSegmentNumber))
765810

766-
frame_seg_nums_arr = np.array(frame_seg_nums)
811+
frame_seg_nums_arr = cp.array(frame_seg_nums) if self.to_gpu else np.array(frame_seg_nums)
767812

768813
seg_descriptions = {int(f.SegmentNumber): f for f in img.SegmentSequence}
769814

770-
for i in np.unique(frame_seg_nums_arr):
771-
indices = np.where(frame_seg_nums_arr == i)[0]
772-
yield (img.pixel_array[indices, ...], seg_descriptions[i])
815+
for i in np.unique(frame_seg_nums_arr) if not self.to_gpu else cp.unique(frame_seg_nums_arr):
816+
indices = np.where(frame_seg_nums_arr == i)[0] if not self.to_gpu else cp.where(frame_seg_nums_arr == i)[0]
817+
yield (array_data[indices, ...], seg_descriptions[i])
773818

774-
def _get_seg_data(self, img):
819+
def _get_seg_data(self, img, filename):
775820
"""
776821
Get the array data and metadata of the segmentation image.
777822
778823
Args:
779824
img: a Pydicom dataset object that has attribute "SegmentSequence".
825+
filename: the file path of the image.
780826
781827
"""
782828

783829
metadata = self._get_meta_dict(img)
784830
n_classes = len(img.SegmentSequence)
785-
spatial_shape = list(img.pixel_array.shape)
831+
array_data = self._get_array_data(img, filename)
832+
spatial_shape = list(array_data.shape)
786833
spatial_shape[0] = spatial_shape[0] // n_classes
787834

788835
if self.label_dict is not None:
789836
metadata["labels"] = self.label_dict
790-
all_segs = np.zeros([*spatial_shape, len(self.label_dict)])
837+
if self.to_gpu:
838+
all_segs = cp.zeros([*spatial_shape, len(self.label_dict)], dtype=array_data.dtype)
839+
else:
840+
all_segs = np.zeros([*spatial_shape, len(self.label_dict)], dtype=array_data.dtype)
791841
else:
792842
metadata["labels"] = {}
793-
all_segs = np.zeros([*spatial_shape, n_classes])
843+
if self.to_gpu:
844+
all_segs = cp.zeros([*spatial_shape, n_classes], dtype=array_data.dtype)
845+
else:
846+
all_segs = np.zeros([*spatial_shape, n_classes], dtype=array_data.dtype)
794847

795-
for i, (frames, description) in enumerate(self._get_frame_data(img)):
848+
for i, (frames, description) in enumerate(self._get_frame_data(img, filename, array_data)):
796849
segment_label = getattr(description, "SegmentLabel", f"label_{i}")
797850
class_name = getattr(description, "SegmentDescription", segment_label)
798851
if class_name not in metadata["labels"].keys():
@@ -840,19 +893,51 @@ def _get_seg_data(self, img):
840893

841894
return all_segs, metadata
842895

843-
def _get_array_data(self, img):
896+
def _get_array_data(self, img, filename):
844897
"""
845898
Get the array data of the image. If `RescaleSlope` and `RescaleIntercept` are available, the raw array data
846-
will be rescaled. The output data has the dtype np.float32 if the rescaling is applied.
899+
will be rescaled. The output data has the dtype float32 if the rescaling is applied.
847900
848901
Args:
849902
img: a Pydicom dataset object.
903+
filename: the file path of the image.
850904
851905
"""
852906
# process Dicom series
853-
if not hasattr(img, "pixel_array"):
854-
raise ValueError(f"dicom data: {img.filename} does not have pixel_array.")
855-
data = img.pixel_array
907+
908+
if self.to_gpu:
909+
rows = img.Rows
910+
columns = img.Columns
911+
bits_allocated = img.BitsAllocated
912+
samples_per_pixel = img.SamplesPerPixel
913+
number_of_frames = getattr(img, 'NumberOfFrames', 1)
914+
pixel_representation = img.PixelRepresentation
915+
916+
if bits_allocated == 8:
917+
dtype = cp.int8 if pixel_representation == 1 else cp.uint8
918+
elif bits_allocated == 16:
919+
dtype = cp.int16 if pixel_representation == 1 else cp.uint16
920+
elif bits_allocated == 32:
921+
dtype = cp.int32 if pixel_representation == 1 else cp.uint32
922+
else:
923+
raise ValueError("Unsupported BitsAllocated value")
924+
925+
bytes_per_pixel = bits_allocated // 8
926+
total_pixels = rows * columns * samples_per_pixel * number_of_frames
927+
expected_pixel_data_length = total_pixels * bytes_per_pixel
928+
929+
offset = img.get_item(0x7FE00010, keep_deferred=True).value_tell
930+
931+
with kvikio.CuFile(filename, "r") as f:
932+
buffer = cp.empty(expected_pixel_data_length, dtype=cp.int8)
933+
f.read(buffer, expected_pixel_data_length, offset)
934+
935+
data = buffer.view(dtype).reshape((number_of_frames, rows, columns))
936+
937+
else:
938+
if not hasattr(img, "pixel_array"):
939+
raise ValueError(f"dicom data: {filename} does not have pixel_array.")
940+
data = img.pixel_array
856941

857942
slope, offset = 1.0, 0.0
858943
rescale_flag = False
@@ -862,8 +947,12 @@ def _get_array_data(self, img):
862947
if hasattr(img, "RescaleIntercept"):
863948
offset = img.RescaleIntercept
864949
rescale_flag = True
950+
865951
if rescale_flag:
866-
data = data.astype(np.float32) * slope + offset
952+
if self.to_gpu:
953+
data = data.astype(cp.float32) * slope + offset
954+
else:
955+
data = data.astype(np.float32) * slope + offset
867956

868957
return data
869958

@@ -884,8 +973,6 @@ class NibabelReader(ImageReader):
884973
Default is False. CuPy and Kvikio are required for this option.
885974
Note: For compressed NIfTI files, some operations may still be performed on CPU memory,
886975
and the acceleration may not be significant. In some cases, it may be slower than loading on CPU.
887-
In practical use, it's recommended to add a warm up call before the actual loading.
888-
A related tutorial will be prepared in the future, and the document will be updated accordingly.
889976
kwargs: additional args for `nibabel.load` API. more details about available args:
890977
https://github.com/nipy/nibabel/blob/master/nibabel/loadsave.py
891978

0 commit comments

Comments
 (0)