@@ -257,14 +257,14 @@ def __init__(self, predictor, model_folder, model_name="model.pt"):
257257 # End Block
258258 self .network_weights = self .predictor .network
259259
260- def forward (self , x : Union [ MetaTensor , tuple [ MetaTensor ]]) :
260+ def forward (self , x : MetaTensor ) -> MetaTensor :
261261 """
262262 Forward pass for the nnUNet model.
263263
264264 :no-index:
265265
266266 Args:
267- x (Union[ MetaTensor, Tuple[MetaTensor]] ): Input tensor or a tuple of MetaTensors . If the input is a tuple,
267+ x (MetaTensor): Input tensor. If the input is a tuple,
268268 it is assumed to be a decollated batch (list of tensors). Otherwise, it is assumed to be a collated batch.
269269
270270 Returns:
@@ -280,9 +280,9 @@ def forward(self, x: Union[MetaTensor, tuple[MetaTensor]]):
280280 - The predictions are converted to torch tensors, with added batch and channel dimensions.
281281 - The output tensor is concatenated along the batch dimension and returned as a MetaTensor with the same metadata.
282282 """
283- if isinstance (x , tuple ): # if batch is decollated (list of tensors)
284- properties_or_list_of_properties = []
285- image_or_list_of_images = []
283+ # if isinstance(x, tuple): # if batch is decollated (list of tensors)
284+ # properties_or_list_of_properties = []
285+ # image_or_list_of_images = []
286286
287287 # for img in x:
288288 # if isinstance(img, MetaTensor):
@@ -291,15 +291,16 @@ def forward(self, x: Union[MetaTensor, tuple[MetaTensor]]):
291291 # else:
292292 # raise TypeError("Input must be a MetaTensor or a tuple of MetaTensors.")
293293
294- else : # if batch is collated
295- if isinstance (x , MetaTensor ):
296- if "pixdim" in x .meta :
297- properties_or_list_of_properties = {"spacing" : x .meta ["pixdim" ][0 ][1 :4 ].numpy ().tolist ()}
298- else :
299- properties_or_list_of_properties = {"spacing" : [1.0 , 1.0 , 1.0 ]}
294+ #else: # if batch is collated
295+ if isinstance (x , MetaTensor ):
296+ if "pixdim" in x .meta :
297+ properties_or_list_of_properties = {"spacing" : x .meta ["pixdim" ][0 ][1 :4 ].numpy ().tolist ()}
300298 else :
301- raise TypeError ("Input must be a MetaTensor or a tuple of MetaTensors." )
302- image_or_list_of_images = x .cpu ().numpy ()[0 , :]
299+ properties_or_list_of_properties = {"spacing" : [1.0 , 1.0 , 1.0 ]}
300+ else :
301+ raise TypeError ("Input must be a MetaTensor or a tuple of MetaTensors." )
302+
303+ image_or_list_of_images = x .cpu ().numpy ()[0 , :]
303304
304305 # input_files should be a list of file paths, one per modality
305306 prediction_output = self .predictor .predict_from_list_of_npy_arrays (
@@ -318,10 +319,10 @@ def forward(self, x: Union[MetaTensor, tuple[MetaTensor]]):
318319 out_tensors .append (torch .from_numpy (np .expand_dims (np .expand_dims (out , 0 ), 0 )))
319320 out_tensor = torch .cat (out_tensors , 0 ) # Concatenate along batch dimension
320321
321- if type (x ) is tuple :
322- return MetaTensor (out_tensor , meta = x [0 ].meta )
323- else :
324- return MetaTensor (out_tensor , meta = x .meta )
322+ # if type(x) is tuple:
323+ # return MetaTensor(out_tensor, meta=x[0].meta)
324+ # else:
325+ return MetaTensor (out_tensor , meta = x .meta )
325326
326327
327328def get_nnunet_monai_predictor (model_folder , model_name = "model.pt" ):
0 commit comments