1111from __future__ import annotations
1212
1313import os
14+ import shutil
15+ from pathlib import Path
1416
1517import numpy as np
1618import torch
1719from torch ._dynamo import OptimizedModule
1820from torch .backends import cudnn
1921
20- from pathlib import Path
21- import shutil
2222from monai .data .meta_tensor import MetaTensor
2323from monai .utils import optional_import
2424
2525join , _ = optional_import ("batchgenerators.utilities.file_and_folder_operations" , name = "join" )
2626load_json , _ = optional_import ("batchgenerators.utilities.file_and_folder_operations" , name = "load_json" )
2727
28- __all__ = ["get_nnunet_trainer" , "get_nnunet_monai_predictor" , "nnUNetMONAIModelWrapper " ]
28+ __all__ = ["get_nnunet_trainer" , "get_nnunet_monai_predictor" , "convert_nnunet_to_monai_bundle" , "ModelnnUNetWrapper " ]
2929
3030
3131def get_nnunet_trainer (
@@ -42,7 +42,7 @@ def get_nnunet_trainer(
4242 only_run_validation = False ,
4343 disable_checkpointing = False ,
4444 val_with_best = False ,
45- device = torch . device ( "cuda" ) ,
45+ device = "cuda" ,
4646 pretrained_model = None ,
4747):
4848 """
@@ -98,7 +98,7 @@ def get_nnunet_trainer(
9898 Whether to disable checkpointing. Default is False.
9999 val_with_best : bool, optional
100100 Whether to validate with the best model. Default is False.
101- device : torch.device , optional
101+ device : str , optional
102102 The device to be used for training. Default is 'cuda'.
103103 pretrained_model : str, optional
104104 Path to the pretrained model file.
@@ -130,7 +130,7 @@ def get_nnunet_trainer(
130130 trainer_class_name ,
131131 plans_identifier ,
132132 use_compressed_data ,
133- device = device ,
133+ device = torch . device ( device ) ,
134134 )
135135 if disable_checkpointing :
136136 nnunet_trainer .disable_checkpointing = disable_checkpointing
@@ -150,7 +150,7 @@ def get_nnunet_trainer(
150150 return nnunet_trainer
151151
152152
153- class nnUNetMONAIModelWrapper (torch .nn .Module ):
153+ class ModelnnUNetWrapper (torch .nn .Module ):
154154 """
155155 A wrapper class for nnUNet model integration with MONAI framework.
156156 The wrapper can be use to integrate the nnUNet Bundle within MONAI framework for inference.
@@ -188,7 +188,7 @@ def __init__(self, predictor, model_folder, model_name="model.pt"):
188188
189189 from nnunetv2 .utilities .plans_handling .plans_handler import PlansManager
190190
191- ## Block Added from nnUNet/nnunetv2/inference/predict_from_raw_data.py#nnUNetPredictor
191+ # Block Added from nnUNet/nnunetv2/inference/predict_from_raw_data.py#nnUNetPredictor
192192 dataset_json = load_json (join (model_training_output_dir , "dataset.json" ))
193193 plans = load_json (join (model_training_output_dir , "plans.json" ))
194194 plans_manager = PlansManager (plans )
@@ -253,17 +253,17 @@ def __init__(self, predictor, model_folder, model_name="model.pt"):
253253 ):
254254 print ("Using torch.compile" )
255255 predictor .network = torch .compile (self .network )
256- ## End Block
256+ # End Block
257257 self .network_weights = self .predictor .network
258258
259259 def forward (self , x ):
260260 if type (x ) is tuple : # if batch is decollated (list of tensors)
261261 input_files = [img .meta ["filename_or_obj" ][0 ] for img in x ]
262- else : # if batch is collated
262+ else : # if batch is collated
263263 input_files = x .meta ["filename_or_obj" ]
264264 if type (input_files ) is str :
265265 input_files = [input_files ]
266-
266+
267267 # input_files should be a list of file paths, one per modality
268268 prediction_output = self .predictor .predict_from_files (
269269 [input_files ],
@@ -277,11 +277,11 @@ def forward(self, x):
277277 part_id = 0 ,
278278 )
279279 # prediction_output is a list of numpy arrays, with dimensions (H, W, D), output from ArgMax
280-
280+
281281 out_tensors = []
282- for out in prediction_output : # Add batch and channel dimensions
282+ for out in prediction_output : # Add batch and channel dimensions
283283 out_tensors .append (torch .from_numpy (np .expand_dims (np .expand_dims (out , 0 ), 0 )))
284- out_tensor = torch .cat (out_tensors , 0 ) # Concatenate along batch dimension
284+ out_tensor = torch .cat (out_tensors , 0 ) # Concatenate along batch dimension
285285
286286 if type (x ) is tuple :
287287 return MetaTensor (out_tensor , meta = x [0 ].meta )
@@ -338,7 +338,7 @@ def get_nnunet_monai_predictor(model_folder, model_name="model.pt"):
338338 allow_tqdm = True ,
339339 )
340340 # initializes the network architecture, loads the checkpoint
341- wrapper = nnUNetMONAIModelWrapper (predictor , model_folder , model_name )
341+ wrapper = ModelnnUNetWrapper (predictor , model_folder , model_name )
342342 return wrapper
343343
344344
@@ -376,29 +376,32 @@ def convert_nnunet_to_monai_bundle(nnunet_config, bundle_root_folder, fold=0):
376376
377377 from nnunetv2 .utilities .dataset_name_id_conversion import maybe_convert_to_dataset_name
378378
379-
380379 dataset_name = maybe_convert_to_dataset_name (nnunet_config ["dataset_name_or_id" ])
381380 nnunet_model_folder = Path (os .environ ["nnUNet_results" ]).joinpath (
382- dataset_name ,
383- f" { nnunet_trainer } __ { nnunet_plans } __ { nnunet_configuration } " )
384-
385- nnunet_checkpoint_final = torch .load (Path (nnunet_model_folder ).joinpath (f"fold_{ fold } " ,"checkpoint_final.pth" ))
386- nnunet_checkpoint_best = torch .load (Path (nnunet_model_folder ).joinpath (f"fold_{ fold } " ,"checkpoint_best.pth" ))
381+ dataset_name , f" { nnunet_trainer } __ { nnunet_plans } __ { nnunet_configuration } "
382+ )
383+
384+ nnunet_checkpoint_final = torch .load (Path (nnunet_model_folder ).joinpath (f"fold_{ fold } " , "checkpoint_final.pth" ))
385+ nnunet_checkpoint_best = torch .load (Path (nnunet_model_folder ).joinpath (f"fold_{ fold } " , "checkpoint_best.pth" ))
387386
388387 nnunet_checkpoint = {}
389- nnunet_checkpoint [' inference_allowed_mirroring_axes' ] = nnunet_checkpoint_final [' inference_allowed_mirroring_axes' ]
390- nnunet_checkpoint [' init_args' ] = nnunet_checkpoint_final [' init_args' ]
391- nnunet_checkpoint [' trainer_name' ] = nnunet_checkpoint_final [' trainer_name' ]
388+ nnunet_checkpoint [" inference_allowed_mirroring_axes" ] = nnunet_checkpoint_final [" inference_allowed_mirroring_axes" ]
389+ nnunet_checkpoint [" init_args" ] = nnunet_checkpoint_final [" init_args" ]
390+ nnunet_checkpoint [" trainer_name" ] = nnunet_checkpoint_final [" trainer_name" ]
392391
393- torch .save (nnunet_checkpoint , Path (bundle_root_folder ).joinpath ("models" ,"nnunet_checkpoint.pth" ))
392+ torch .save (nnunet_checkpoint , Path (bundle_root_folder ).joinpath ("models" , "nnunet_checkpoint.pth" ))
394393
395394 monai_last_checkpoint = {}
396- monai_last_checkpoint [' network_weights' ] = nnunet_checkpoint_final [' network_weights' ]
397- torch .save (monai_last_checkpoint , Path (bundle_root_folder ).joinpath ("models" ,"model.pt" ))
395+ monai_last_checkpoint [" network_weights" ] = nnunet_checkpoint_final [" network_weights" ]
396+ torch .save (monai_last_checkpoint , Path (bundle_root_folder ).joinpath ("models" , "model.pt" ))
398397
399398 monai_best_checkpoint = {}
400- monai_best_checkpoint [' network_weights' ] = nnunet_checkpoint_best [' network_weights' ]
401- torch .save (monai_best_checkpoint , Path (bundle_root_folder ).joinpath ("models" ,"best_model.pt" ))
399+ monai_best_checkpoint [" network_weights" ] = nnunet_checkpoint_best [" network_weights" ]
400+ torch .save (monai_best_checkpoint , Path (bundle_root_folder ).joinpath ("models" , "best_model.pt" ))
402401
403- shutil .copy (Path (nnunet_model_folder ).joinpath ("plans.json" ),Path (bundle_root_folder ).joinpath ("models" ,"plans.json" ))
404- shutil .copy (Path (nnunet_model_folder ).joinpath ("dataset.json" ),Path (bundle_root_folder ).joinpath ("models" ,"dataset.json" ))
402+ shutil .copy (
403+ Path (nnunet_model_folder ).joinpath ("plans.json" ), Path (bundle_root_folder ).joinpath ("models" , "plans.json" )
404+ )
405+ shutil .copy (
406+ Path (nnunet_model_folder ).joinpath ("dataset.json" ), Path (bundle_root_folder ).joinpath ("models" , "dataset.json" )
407+ )
0 commit comments