1818import torch
1919from torch .backends import cudnn
2020
21- from typing import Union , Tuple
21+ from typing import Union
2222from monai .data .meta_tensor import MetaTensor
2323from monai .utils import optional_import
24- from nnunetv2 .training .logging .nnunet_logger import nnUNetLogger
2524
2625join , _ = optional_import ("batchgenerators.utilities.file_and_folder_operations" , name = "join" )
2726load_json , _ = optional_import ("batchgenerators.utilities.file_and_folder_operations" , name = "load_json" )
@@ -258,7 +257,7 @@ def __init__(self, predictor, model_folder, model_name="model.pt"):
258257 # End Block
259258 self .network_weights = self .predictor .network
260259
261- def forward (self , x : Union [MetaTensor , Tuple [MetaTensor ]]):
260+ def forward (self , x : Union [MetaTensor , tuple [MetaTensor ]]):
262261 """
263262 Forward pass for the nnUNet model.
264263
@@ -291,7 +290,7 @@ def forward(self, x: Union[MetaTensor, Tuple[MetaTensor]]):
291290 # image_or_list_of_images.append(img.cpu().numpy()[0,:])
292291 #else:
293292 # raise TypeError("Input must be a MetaTensor or a tuple of MetaTensors.")
294-
293+
295294 else : # if batch is collated
296295 if isinstance (x , MetaTensor ):
297296 if 'pixdim' in x .meta :
@@ -307,7 +306,7 @@ def forward(self, x: Union[MetaTensor, Tuple[MetaTensor]]):
307306 image_or_list_of_images ,
308307 None ,
309308 properties_or_list_of_properties ,
310- truncated_ofname = None ,
309+ truncated_ofname = None ,
311310 save_probabilities = False ,
312311 num_processes = 2 ,
313312 num_processes_segmentation_export = 2
@@ -441,8 +440,8 @@ def convert_nnunet_to_monai_bundle(nnunet_config, bundle_root_folder, fold=0):
441440 shutil .copy (
442441 Path (nnunet_model_folder ).joinpath ("plans.json" ), Path (bundle_root_folder ).joinpath ("models" , "plans.json" )
443442 )
444-
443+
445444 if not os .path .exists (os .path .join (bundle_root_folder , "models" , "dataset.json" )):
446445 shutil .copy (
447446 Path (nnunet_model_folder ).joinpath ("dataset.json" ), Path (bundle_root_folder ).joinpath ("models" , "dataset.json" )
448- )
447+ )
0 commit comments