Skip to content

Commit 4ca028a

Browse files
Enhance nnUNet bundle: add nnUNetPredictor import and update type hints in ModelnnUNetWrapper
1 parent 18d5a4c commit 4ca028a

File tree

1 file changed

+12
-13
lines changed

1 file changed

+12
-13
lines changed

monai/apps/nnunet/nnunet_bundle.py

Lines changed: 12 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
join, _ = optional_import("batchgenerators.utilities.file_and_folder_operations", name="join")
2626
load_json, _ = optional_import("batchgenerators.utilities.file_and_folder_operations", name="load_json")
2727
nnUNetTrainer, _ = optional_import("nnunetv2.training.nnUNetTrainer", name="nnUNetTrainer")
28+
nnUNetPredictor, _ = optional_import("nnunetv2.inference.predict_from_raw_data", name="nnUNetPredictor")
2829

2930
__all__ = [
3031
"get_nnunet_trainer",
@@ -171,9 +172,7 @@ class ModelnnUNetWrapper(torch.nn.Module):
171172
restoring network architecture, and setting up the predictor for inference.
172173
"""
173174

174-
from nnunetv2.inference.predict_from_raw_data import nnUNetPredictor
175-
176-
def __init__(self, predictor: nnUNetPredictor, model_folder: Union[str, Path], model_name: str = "model.pt"):
175+
def __init__(self, predictor: nnUNetPredictor, model_folder: Union[str, Path], model_name: str = "model.pt"): # type: ignore
177176
super().__init__()
178177
self.predictor = predictor
179178

@@ -228,18 +227,18 @@ def __init__(self, predictor: nnUNetPredictor, model_folder: Union[str, Path], m
228227
enable_deep_supervision=False,
229228
)
230229

231-
predictor.plans_manager = plans_manager
232-
predictor.configuration_manager = configuration_manager
233-
predictor.list_of_parameters = parameters
234-
predictor.network = network
235-
predictor.dataset_json = dataset_json
236-
predictor.trainer_name = trainer_name
237-
predictor.allowed_mirroring_axes = inference_allowed_mirroring_axes
238-
predictor.label_manager = plans_manager.get_label_manager(dataset_json)
230+
predictor.plans_manager = plans_manager # type: ignore
231+
predictor.configuration_manager = configuration_manager # type: ignore
232+
predictor.list_of_parameters = parameters # type: ignore
233+
predictor.network = network # type: ignore
234+
predictor.dataset_json = dataset_json # type: ignore
235+
predictor.trainer_name = trainer_name # type: ignore
236+
predictor.allowed_mirroring_axes = inference_allowed_mirroring_axes # type: ignore
237+
predictor.label_manager = plans_manager.get_label_manager(dataset_json) # type: ignore
239238
if ("nnUNet_compile" in os.environ.keys()) and (os.environ["nnUNet_compile"].lower() in ("true", "1", "t")):
240239
print("Using torch.compile")
241240
# End Block
242-
self.network_weights = self.predictor.network
241+
self.network_weights = self.predictor.network # type: ignore
243242

244243
def forward(self, x: MetaTensor) -> MetaTensor:
245244
"""
@@ -282,7 +281,7 @@ def forward(self, x: MetaTensor) -> MetaTensor:
282281
image_or_list_of_images = x.cpu().numpy()[0, :]
283282

284283
# input_files should be a list of file paths, one per modality
285-
prediction_output = self.predictor.predict_from_list_of_npy_arrays(
284+
prediction_output = self.predictor.predict_from_list_of_npy_arrays( # type: ignore
286285
image_or_list_of_images,
287286
None,
288287
properties_or_list_of_properties,

0 commit comments

Comments
 (0)