2323from pathlib import Path
2424from typing import TYPE_CHECKING , Any
2525import torch
26-
26+ from monai . data . meta_tensor import MetaTensor
2727import numpy as np
2828from torch .utils .data ._utils .collate import np_str_obj_array_pattern
2929
@@ -1038,13 +1038,22 @@ def _get_array_data(self, img):
10381038@require_pkg (pkg_name = "kvikio" )
10391039class NibabelGPUReader (NibabelReader ):
10401040
1041- def _gds_load (self , file_path ):
1042- file_size = os .path .getsize (file_path )
1041+ def read (self , filename : PathLike , ** kwargs ):
1042+ """
1043+ Read image data from specified file or files, it can read a list of images
1044+ and stack them together as multi-channel data in `get_data()`.
1045+ Note that the returned object is Nibabel image object or list of Nibabel image objects.
1046+
1047+ Args:
1048+ data: file name.
1049+
1050+ """
1051+ file_size = os .path .getsize (filename )
10431052 image = cp .empty (file_size , dtype = cp .uint8 )
1044- with kvikio .CuFile (file_path , "r" ) as f :
1053+ with kvikio .CuFile (filename , "r" ) as f :
10451054 f .read (image )
10461055
1047- if file_path .endswith (".gz" ):
1056+ if filename .endswith (".gz" ):
10481057 # for compressed data, have to tansfer to CPU to decompress
10491058 # and then transfer back to GPU. It is not efficient compared to .nii file
10501059 # but it's still faster than Nibabel's default reader.
@@ -1056,29 +1065,8 @@ def _gds_load(self, file_path):
10561065
10571066 file_size = len (decompressed_data )
10581067 image = cp .asarray (np .frombuffer (decompressed_data , dtype = np .uint8 ))
1059-
10601068 return image
10611069
1062- def read (self , data : Sequence [PathLike ] | PathLike , ** kwargs ):
1063- """
1064- Read image data from specified file or files, it can read a list of images
1065- and stack them together as multi-channel data in `get_data()`.
1066- Note that the returned object is Nibabel image object or list of Nibabel image objects.
1067-
1068- Args:
1069- data: file name or a list of file names to read.
1070-
1071- """
1072- img_ = []
1073-
1074- filenames : Sequence [PathLike ] = ensure_tuple (data )
1075- kwargs_ = self .kwargs .copy ()
1076- kwargs_ .update (kwargs )
1077- for name in filenames :
1078- img = self ._gds_load (name )
1079- img_ .append (img ) # type: ignore
1080- return img_ if len (filenames ) > 1 else img_ [0 ]
1081-
10821070 def get_data (self , img ):
10831071 """
10841072 Extract data array and metadata from loaded image and return them.
@@ -1088,39 +1076,38 @@ def get_data(self, img):
10881076 and the metadata of the first image is used to present the output metadata.
10891077
10901078 Args:
1091- img: a Nibabel image object loaded from an image file or a list of Nibabel image objects .
1079+ img: a Nibabel image object loaded from an image file.
10921080
10931081 """
1094- compatible_meta : dict = {}
1095- img_array = []
1096- for i in ensure_tuple (img ):
1097- header = self ._get_header (i )
1098- data_offset = header .get_data_offset ()
1099- data_shape = header .get_data_shape ()
1100- data_dtype = header .get_data_dtype ()
1101- affine = header .get_best_affine ()
1102- meta = dict (header )
1103- meta [MetaKeys .AFFINE ] = affine
1104- meta [MetaKeys .ORIGINAL_AFFINE ] = affine
1105- # TODO: as_closest_canonical
1106- # TODO: correct_nifti_header_if_necessary
1107- meta [MetaKeys .SPATIAL_SHAPE ] = data_shape
1108- # TODO: figure out why always RAS for NibabelReader ?
1109- # meta[MetaKeys.SPACE] = SpaceKeys.RAS
1110-
1111- data = i [data_offset :].view (data_dtype ).reshape (data_shape , order = "F" )
1112- # TODO: check channel
1113- # if self.squeeze_non_spatial_dims:
1114- img_array .append (data )
1115- if self .channel_dim is None : # default to "no_channel" or -1
1116- meta [MetaKeys .ORIGINAL_CHANNEL_DIM ] = (
1117- float ("nan" ) if len (data .shape ) == len (meta [MetaKeys .SPATIAL_SHAPE ]) else - 1
1118- )
1119- else :
1120- meta [MetaKeys .ORIGINAL_CHANNEL_DIM ] = self .channel_dim
1121- _copy_compatible_dict (meta , compatible_meta )
11221082
1123- return self ._stack_images (img_array , compatible_meta ), compatible_meta
1083+ # TODO: use a formal way for device
1084+ device = torch .device ('cuda:0' if torch .cuda .is_available () else 'cpu' )
1085+
1086+ header = self ._get_header (img )
1087+ data_offset = header .get_data_offset ()
1088+ data_shape = header .get_data_shape ()
1089+ data_dtype = header .get_data_dtype ()
1090+ affine = header .get_best_affine ()
1091+ meta = dict (header )
1092+ meta [MetaKeys .AFFINE ] = affine
1093+ meta [MetaKeys .ORIGINAL_AFFINE ] = affine
1094+ # TODO: as_closest_canonical
1095+ # TODO: correct_nifti_header_if_necessary
1096+ meta [MetaKeys .SPATIAL_SHAPE ] = data_shape
1097+ # TODO: figure out why always RAS for NibabelReader ?
1098+ # meta[MetaKeys.SPACE] = SpaceKeys.RAS
1099+
1100+ data = img [data_offset :].view (data_dtype ).reshape (data_shape , order = "F" )
1101+ # TODO: check channel
1102+ # if self.squeeze_non_spatial_dims:
1103+ if self .channel_dim is None : # default to "no_channel" or -1
1104+ meta [MetaKeys .ORIGINAL_CHANNEL_DIM ] = (
1105+ float ("nan" ) if len (data .shape ) == len (meta [MetaKeys .SPATIAL_SHAPE ]) else - 1
1106+ )
1107+ else :
1108+ meta [MetaKeys .ORIGINAL_CHANNEL_DIM ] = self .channel_dim
1109+
1110+ return MetaTensor (data , affine = affine , meta = meta , device = device )
11241111
11251112 def _get_header (self , img ):
11261113 """
@@ -1139,15 +1126,6 @@ def _get_header(self, img):
11391126 pass
11401127 return header
11411128
1142- def _stack_images (self , image_list : list , meta_dict : dict ):
1143- if len (image_list ) <= 1 :
1144- return image_list [0 ]
1145- if not is_no_channel (meta_dict .get (MetaKeys .ORIGINAL_CHANNEL_DIM , None )):
1146- channel_dim = int (meta_dict [MetaKeys .ORIGINAL_CHANNEL_DIM ])
1147- return torch .cat (image_list , axis = channel_dim )
1148- meta_dict [MetaKeys .ORIGINAL_CHANNEL_DIM ] = 0
1149- return torch .stack (image_list , dim = 0 )
1150-
11511129
11521130class NumpyReader (ImageReader ):
11531131 """
0 commit comments