|
14 | 14 | import glob |
15 | 15 | import os |
16 | 16 | import re |
| 17 | +import gzip |
| 18 | +import io |
17 | 19 | import warnings |
18 | 20 | from abc import ABC, abstractmethod |
19 | 21 | from collections.abc import Callable, Iterable, Iterator, Sequence |
20 | 22 | from dataclasses import dataclass |
21 | 23 | from pathlib import Path |
22 | 24 | from typing import TYPE_CHECKING, Any |
| 25 | +import torch |
23 | 26 |
|
24 | 27 | import numpy as np |
25 | 28 | from torch.utils.data._utils.collate import np_str_obj_array_pattern |
|
41 | 44 | import pydicom |
42 | 45 | from nibabel.nifti1 import Nifti1Image |
43 | 46 | from PIL import Image as PILImage |
| 47 | + import cupy as cp |
| 48 | + import kvikio |
44 | 49 |
|
45 | | - has_nrrd = has_itk = has_nib = has_pil = has_pydicom = True |
| 50 | + has_nrrd = has_itk = has_nib = has_pil = has_pydicom = has_cp = has_kvikio = True |
46 | 51 | else: |
47 | 52 | itk, has_itk = optional_import("itk", allow_namespace_pkg=True) |
48 | 53 | nib, has_nib = optional_import("nibabel") |
49 | 54 | Nifti1Image, _ = optional_import("nibabel.nifti1", name="Nifti1Image") |
50 | 55 | PILImage, has_pil = optional_import("PIL.Image") |
51 | 56 | pydicom, has_pydicom = optional_import("pydicom") |
52 | 57 | nrrd, has_nrrd = optional_import("nrrd", allow_namespace_pkg=True) |
| 58 | + cp, has_cp = optional_import("cupy") |
| 59 | + kvikio, has_kvikio = optional_import("kvikio") |
53 | 60 |
|
54 | | -__all__ = ["ImageReader", "ITKReader", "NibabelReader", "NumpyReader", "PILReader", "PydicomReader", "NrrdReader"] |
| 61 | +__all__ = ["ImageReader", "ITKReader", "NibabelReader", "NibabelGPUReader", "NumpyReader", "PILReader", "PydicomReader", "NrrdReader"] |
55 | 62 |
|
56 | 63 |
|
57 | 64 | class ImageReader(ABC): |
@@ -1024,6 +1031,122 @@ def _get_array_data(self, img): |
1024 | 1031 |
|
1025 | 1032 | """ |
1026 | 1033 | return np.asanyarray(img.dataobj, order="C") |
| 1034 | + |
| 1035 | + |
@require_pkg(pkg_name="nibabel")
@require_pkg(pkg_name="cupy")
@require_pkg(pkg_name="kvikio")
class NibabelGPUReader(NibabelReader):
    """
    Load NIfTI files directly into GPU memory using kvikio (GPUDirect Storage).

    Unlike ``NibabelReader``, ``read`` returns raw file bytes as 1D ``cupy.uint8``
    arrays (not Nibabel image objects); ``get_data`` then parses the NIfTI-1 header
    on CPU and views the remaining bytes as the image array on GPU.
    """

    def _gds_load(self, file_path):
        """
        Read the raw bytes of ``file_path`` into a 1D uint8 CuPy array via kvikio.

        For ``.gz`` files the compressed bytes are copied to CPU, decompressed with
        ``gzip``, and transferred back to GPU.

        Args:
            file_path: path (str or PathLike) of the file to load.

        """
        # normalize so PathLike inputs work with ``str.endswith`` below
        file_path = str(file_path)
        file_size = os.path.getsize(file_path)
        image = cp.empty(file_size, dtype=cp.uint8)
        with kvikio.CuFile(file_path, "r") as f:
            f.read(image)

        if file_path.endswith(".gz"):
            # for compressed data, have to transfer to CPU to decompress
            # and then transfer back to GPU. It is not efficient compared to a .nii file
            # but it's still faster than Nibabel's default reader.
            # TODO: can benchmark more; it may not be needed since we don't have to use .gz,
            # as it wastes time especially in training
            compressed_data = cp.asnumpy(image)
            with gzip.GzipFile(fileobj=io.BytesIO(compressed_data)) as gz_file:
                decompressed_data = gz_file.read()
            image = cp.asarray(np.frombuffer(decompressed_data, dtype=np.uint8))

        return image

    def read(self, data: Sequence[PathLike] | PathLike, **kwargs):
        """
        Read image data from specified file or files, it can read a list of images
        and stack them together as multi-channel data in `get_data()`.
        Note that each returned object is a 1D ``cupy.uint8`` array holding the raw
        file bytes, to be parsed by ``get_data()``.

        Args:
            data: file name or a list of file names to read.
            kwargs: currently unused; accepted for interface compatibility.

        """
        filenames: Sequence[PathLike] = ensure_tuple(data)
        img_: list = [self._gds_load(name) for name in filenames]
        return img_ if len(filenames) > 1 else img_[0]

    def get_data(self, img):
        """
        Extract data array and metadata from loaded image and return them.
        This function returns two objects, first is the image data array (on GPU),
        second is a dict of metadata.
        It constructs `affine`, `original_affine`, and `spatial_shape` and stores them in the meta dict.
        When loading a list of files, they are stacked together at a new dimension as the first dimension,
        and the metadata of the first image is used to present the output metadata.

        Args:
            img: a 1D uint8 CuPy array of raw file bytes from ``read``, or a list of them.

        """
        compatible_meta: dict = {}
        img_array = []
        for i in ensure_tuple(img):
            header = self._get_header(i)
            data_offset = header.get_data_offset()
            data_shape = header.get_data_shape()
            data_dtype = header.get_data_dtype()
            affine = header.get_best_affine()
            meta = dict(header)
            meta[MetaKeys.AFFINE] = affine
            meta[MetaKeys.ORIGINAL_AFFINE] = affine
            # TODO: as_closest_canonical
            # TODO: correct_nifti_header_if_necessary
            meta[MetaKeys.SPATIAL_SHAPE] = data_shape
            # TODO: figure out why always RAS for NibabelReader ?
            # meta[MetaKeys.SPACE] = SpaceKeys.RAS

            # NIfTI stores the array Fortran-ordered after the data offset
            data = i[data_offset:].view(data_dtype).reshape(data_shape, order="F")
            # TODO: check channel
            # if self.squeeze_non_spatial_dims:
            img_array.append(data)
            # record channel metadata on the meta dict (not the Nifti1Header,
            # which only accepts NIfTI field names as keys)
            if self.channel_dim is None:  # default to "no_channel" or -1
                meta[MetaKeys.ORIGINAL_CHANNEL_DIM] = (
                    float("nan") if len(data.shape) == len(meta[MetaKeys.SPATIAL_SHAPE]) else -1
                )
            else:
                meta[MetaKeys.ORIGINAL_CHANNEL_DIM] = self.channel_dim
            _copy_compatible_dict(meta, compatible_meta)

        return self._stack_images(img_array, compatible_meta), compatible_meta

    def _get_header(self, img):
        """
        Parse the NIfTI-1 header from the raw byte buffer.

        Args:
            img: a 1D uint8 CuPy array of raw file bytes.

        """
        # the NIfTI-1 header is the first 348 bytes of the file
        header_bytes = cp.asnumpy(img[:348])
        header = nib.Nifti1Header.from_fileobj(io.BytesIO(header_bytes))
        # swap to little endian as PyTorch doesn't support big endian
        try:
            header = header.as_byteswapped("<")
        except ValueError:
            pass
        return header

    def _stack_images(self, image_list: list, meta_dict: dict):
        """
        Stack a list of image arrays: concatenate along the original channel
        dimension when one exists, otherwise stack along a new first dimension.
        """
        if len(image_list) <= 1:
            return image_list[0]
        if not is_no_channel(meta_dict.get(MetaKeys.ORIGINAL_CHANNEL_DIM, None)):
            channel_dim = int(meta_dict[MetaKeys.ORIGINAL_CHANNEL_DIM])
            return torch.cat(image_list, dim=channel_dim)
        meta_dict[MetaKeys.ORIGINAL_CHANNEL_DIM] = 0
        return torch.stack(image_list, dim=0)
1027 | 1150 |
|
1028 | 1151 |
|
1029 | 1152 | class NumpyReader(ImageReader): |
|
0 commit comments