 import os
 import shutil
 from pathlib import Path
+from typing import Optional, Union
 
 import numpy as np
 import torch
 from monai.data.meta_tensor import MetaTensor
 from monai.utils import optional_import
 
-from typing import Union, Optional
 join, _ = optional_import("batchgenerators.utilities.file_and_folder_operations", name="join")
 load_json, _ = optional_import("batchgenerators.utilities.file_and_folder_operations", name="load_json")
 
-__all__ = ["get_nnunet_trainer", "get_nnunet_monai_predictor", "convert_nnunet_to_monai_bundle", "convert_monai_bundle_to_nnunet", "ModelnnUNetWrapper"]
+__all__ = [
+    "get_nnunet_trainer",
+    "get_nnunet_monai_predictor",
+    "get_network_from_nnunet_plans",
+    "convert_nnunet_to_monai_bundle",
+    "convert_monai_bundle_to_nnunet",
+    "ModelnnUNetWrapper",
+]
 
 
 def get_nnunet_trainer(
@@ -107,7 +114,6 @@ def get_nnunet_trainer(
                 )
                 raise e
 
-
     from nnunetv2.run.run_training import get_trainer_from_args, maybe_load_checkpoint
 
     nnunet_trainer = get_trainer_from_args(
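For context on how this factory is meant to be driven, a hedged usage sketch follows; the parameter names and the import path are assumptions, since this hunk only shows the delegation to nnunetv2's `get_trainer_from_args`:

```python
# Hypothetical usage sketch -- parameter names and import path are assumptions,
# not shown in this hunk; only the call into get_trainer_from_args is.
from monai.apps.nnunet import get_nnunet_trainer  # assumed import location

nnunet_trainer = get_nnunet_trainer(
    dataset_name_or_id="Dataset009_Spleen",  # assumed nnUNet-style dataset identifier
    configuration="3d_fullres",
    fold=0,
)
nnunet_trainer.run_training()  # assumed: the returned object behaves like an nnUNetTrainer
```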
@@ -178,7 +184,7 @@ def __init__(self, predictor: object, model_folder: str, model_name: str = "mode
         plans_manager = PlansManager(plans)
 
         parameters = []
-
+
         checkpoint = torch.load(
             join(Path(model_training_output_dir).parent, "nnunet_checkpoint.pth"), map_location=torch.device("cpu")
         )
@@ -190,9 +196,7 @@ def __init__(self, predictor: object, model_folder: str, model_name: str = "mode
             else None
         )
         if Path(model_training_output_dir).joinpath(model_name).is_file():
-            monai_checkpoint = torch.load(
-                join(model_training_output_dir, model_name), map_location=torch.device("cpu")
-            )
+            monai_checkpoint = torch.load(join(model_training_output_dir, model_name), map_location=torch.device("cpu"))
             if "network_weights" in monai_checkpoint.keys():
                 parameters.append(monai_checkpoint["network_weights"])
             else:
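The branch above accepts two checkpoint layouts for `model_name`: a MONAI-bundle-style dict with a `network_weights` key, or an alternative handled in the `else` branch, whose body falls outside this hunk. A minimal sketch, assuming the fallback treats the file as a plain state dict:

```python
# Minimal sketch of the two layouts the branch above dispatches on.
import torch

ckpt = torch.load("model.pt", map_location=torch.device("cpu"))  # path is a placeholder
if "network_weights" in ckpt.keys():
    state_dict = ckpt["network_weights"]  # bundle-style: weights nested under a key
else:
    state_dict = ckpt  # assumed fallback: the file is the state dict itself
```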
@@ -230,10 +234,7 @@ def __init__(self, predictor: object, model_folder: str, model_name: str = "mode
         predictor.trainer_name = trainer_name
         predictor.allowed_mirroring_axes = inference_allowed_mirroring_axes
         predictor.label_manager = plans_manager.get_label_manager(dataset_json)
-        if (
-            ("nnUNet_compile" in os.environ.keys())
-            and (os.environ["nnUNet_compile"].lower() in ("true", "1", "t"))
-        ):
+        if ("nnUNet_compile" in os.environ.keys()) and (os.environ["nnUNet_compile"].lower() in ("true", "1", "t")):
             print("Using torch.compile")
         # End Block
         self.network_weights = self.predictor.network
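The collapsed condition keeps the same opt-in semantics: compilation is only announced when the `nnUNet_compile` environment variable holds one of the accepted truthy strings. For example:

```python
# Opt in before constructing the predictor; per the check above,
# "true", "1", and "t" (case-insensitive) enable it, anything else does not.
import os

os.environ["nnUNet_compile"] = "true"
```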
@@ -265,7 +266,11 @@ def forward(self, x: MetaTensor) -> MetaTensor:
         if "pixdim" in x.meta:
             properties_or_list_of_properties = {"spacing": x.meta["pixdim"][0][1:4].numpy().tolist()}
         elif "affine" in x.meta:
-            spacing = [abs(x.meta['affine'][0][0].item()), abs(x.meta['affine'][1][1].item()), abs(x.meta['affine'][2][2].item())]
+            spacing = [
+                abs(x.meta["affine"][0][0].item()),
+                abs(x.meta["affine"][1][1].item()),
+                abs(x.meta["affine"][2][2].item()),
+            ]
             properties_or_list_of_properties = {"spacing": spacing}
         else:
             properties_or_list_of_properties = {"spacing": [1.0, 1.0, 1.0]}
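The reformatted `elif` branch reads voxel spacing off the diagonal of the affine, which is only exact for axis-aligned affines (no rotation or shear). A toy check of that extraction:

```python
# Toy check of the diagonal-of-affine spacing extraction used above.
# Exact only for axis-aligned affines; a rotated affine would need
# the column norms instead.
import torch

affine = torch.diag(torch.tensor([1.5, 1.5, 2.0, 1.0]))  # synthetic axis-aligned affine
spacing = [abs(affine[i][i].item()) for i in range(3)]
assert spacing == [1.5, 1.5, 2.0]
```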
@@ -348,9 +353,7 @@ def get_nnunet_monai_predictor(model_folder: str, model_name: str = "model.pt")
     return wrapper
 
 
-def convert_nnunet_to_monai_bundle(
-    nnunet_config: dict, bundle_root_folder: str, fold: int = 0
-) -> None:
+def convert_nnunet_to_monai_bundle(nnunet_config: dict, bundle_root_folder: str, fold: int = 0) -> None:
     """
     Convert nnUNet model checkpoints and configuration to MONAI bundle format.
 
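A hedged usage sketch for the collapsed signature; the contents of `nnunet_config` are not shown in this hunk, so the key below is an assumption:

```python
# Usage sketch; the "dataset_name_or_id" key is an assumption -- only the
# dict type of nnunet_config is visible in this diff.
nnunet_config = {"dataset_name_or_id": "009"}
convert_nnunet_to_monai_bundle(nnunet_config, bundle_root_folder="./SpleenBundle", fold=0)
```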
@@ -421,14 +424,14 @@ def convert_nnunet_to_monai_bundle(
 
 
 def get_network_from_nnunet_plans(
-    plans_file: str,
-    dataset_file: str,
-    configuration: str,
-    model_ckpt: Optional[str] = None,
-    model_key_in_ckpt: str = "model"
+    plans_file: str,
+    dataset_file: str,
+    configuration: str,
+    model_ckpt: Optional[str] = None,
+    model_key_in_ckpt: str = "model",
 ) -> torch.nn.Module:
     """
-    Load and initialize a neural network based on nnUNet plans and configuration.
+    Load and initialize an nnUNet network based on nnUNet plans and configuration.
 
     Parameters
     ----------
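Since the full signature is visible in this hunk, a usage sketch is straightforward; the file paths below are placeholders:

```python
# Usage sketch based on the signature above; paths are placeholders.
network = get_network_from_nnunet_plans(
    plans_file="nnUNetPlans.json",
    dataset_file="dataset.json",
    configuration="3d_fullres",
    model_ckpt=None,  # pass a checkpoint path to also restore weights from model_key_in_ckpt
)
print(type(network))  # a torch.nn.Module, per the return annotation
```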
@@ -481,11 +484,7 @@ def get_network_from_nnunet_plans(
     return network
 
 
-def convert_monai_bundle_to_nnunet(
-    nnunet_config: dict,
-    bundle_root_folder: str,
-    fold: int = 0
-) -> None:
+def convert_monai_bundle_to_nnunet(nnunet_config: dict, bundle_root_folder: str, fold: int = 0) -> None:
     """
     Convert a MONAI bundle to nnU-Net format.
 
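The reverse conversion takes the same trio of arguments; as above, the config key is an assumption not visible in this diff:

```python
# Usage sketch for the reverse conversion; the config key is an assumption.
nnunet_config = {"dataset_name_or_id": "009"}
convert_monai_bundle_to_nnunet(nnunet_config, bundle_root_folder="./SpleenBundle", fold=0)
```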
@@ -520,11 +519,7 @@ def convert_monai_bundle_to_nnunet(
     from nnunetv2.utilities.dataset_name_id_conversion import maybe_convert_to_dataset_name
 
     def subfiles(
-        folder: str,
-        join: bool = True,
-        prefix: Optional[str] = None,
-        suffix: Optional[str] = None,
-        sort: bool = True
+        folder: str, join: bool = True, prefix: Optional[str] = None, suffix: Optional[str] = None, sort: bool = True
     ) -> list[str]:
         if join:
             l = os.path.join  # noqa: E741
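The local `subfiles` helper mirrors the batchgenerators utility of the same name; judging from its signature alone, a call would look like:

```python
# Hypothetical call, inferred from the signature above: list checkpoint files
# in a fold directory, filtered by prefix/suffix, sorted, without joining paths.
ckpts = subfiles("./models/fold_0", join=False, prefix="checkpoint", suffix=".pt", sort=True)
```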
@@ -562,7 +557,9 @@ def subfiles(
 
     epochs.sort()
     final_epoch: int = epochs[-1]
-    monai_last_checkpoint: dict = torch.load(f"{bundle_root_folder}/models/fold_{fold}/checkpoint_epoch={final_epoch}.pt")
+    monai_last_checkpoint: dict = torch.load(
+        f"{bundle_root_folder}/models/fold_{fold}/checkpoint_epoch={final_epoch}.pt"
+    )
 
     best_checkpoints: list[str] = subfiles(
         Path(bundle_root_folder).joinpath("models", f"fold_{fold}"),
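The reformatted load relies on the `checkpoint_epoch={N}.pt` naming visible above; a toy sketch of how the last epoch is selected (the actual parsing of `epochs` falls outside this hunk, so the split logic is an assumption based on the file names):

```python
# Toy sketch of final-epoch selection; the parsing of `epochs` is not shown in
# this hunk, so the split logic here is inferred from the checkpoint file names.
names = ["checkpoint_epoch=10.pt", "checkpoint_epoch=2.pt"]
epochs = sorted(int(n.split("=")[1].removesuffix(".pt")) for n in names)
final_epoch = epochs[-1]  # 10 -- numeric sort avoids the lexical trap where "2" > "10"
```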