|
25 | 25 | join, _ = optional_import("batchgenerators.utilities.file_and_folder_operations", name="join") |
26 | 26 | load_json, _ = optional_import("batchgenerators.utilities.file_and_folder_operations", name="load_json") |
27 | 27 | nnUNetTrainer, _ = optional_import("nnunetv2.training.nnUNetTrainer", name="nnUNetTrainer") |
| 28 | +nnUNetPredictor, _ = optional_import("nnunetv2.inference.predict_from_raw_data", name="nnUNetPredictor") |
28 | 29 |
|
29 | 30 | __all__ = [ |
30 | 31 | "get_nnunet_trainer", |
@@ -171,9 +172,7 @@ class ModelnnUNetWrapper(torch.nn.Module): |
171 | 172 | restoring network architecture, and setting up the predictor for inference. |
172 | 173 | """ |
173 | 174 |
|
174 | | - from nnunetv2.inference.predict_from_raw_data import nnUNetPredictor |
175 | | - |
176 | | - def __init__(self, predictor: nnUNetPredictor, model_folder: Union[str, Path], model_name: str = "model.pt"): |
| 175 | + def __init__(self, predictor: nnUNetPredictor, model_folder: Union[str, Path], model_name: str = "model.pt"): # type: ignore |
177 | 176 | super().__init__() |
178 | 177 | self.predictor = predictor |
179 | 178 |
|
@@ -228,18 +227,18 @@ def __init__(self, predictor: nnUNetPredictor, model_folder: Union[str, Path], m |
228 | 227 | enable_deep_supervision=False, |
229 | 228 | ) |
230 | 229 |
|
231 | | - predictor.plans_manager = plans_manager |
232 | | - predictor.configuration_manager = configuration_manager |
233 | | - predictor.list_of_parameters = parameters |
234 | | - predictor.network = network |
235 | | - predictor.dataset_json = dataset_json |
236 | | - predictor.trainer_name = trainer_name |
237 | | - predictor.allowed_mirroring_axes = inference_allowed_mirroring_axes |
238 | | - predictor.label_manager = plans_manager.get_label_manager(dataset_json) |
| 230 | + predictor.plans_manager = plans_manager # type: ignore |
| 231 | + predictor.configuration_manager = configuration_manager # type: ignore |
| 232 | + predictor.list_of_parameters = parameters # type: ignore |
| 233 | + predictor.network = network # type: ignore |
| 234 | + predictor.dataset_json = dataset_json # type: ignore |
| 235 | + predictor.trainer_name = trainer_name # type: ignore |
| 236 | + predictor.allowed_mirroring_axes = inference_allowed_mirroring_axes # type: ignore |
| 237 | + predictor.label_manager = plans_manager.get_label_manager(dataset_json) # type: ignore |
239 | 238 | if ("nnUNet_compile" in os.environ.keys()) and (os.environ["nnUNet_compile"].lower() in ("true", "1", "t")): |
240 | 239 | print("Using torch.compile") |
241 | 240 | # End Block |
242 | | - self.network_weights = self.predictor.network |
| 241 | + self.network_weights = self.predictor.network # type: ignore |
243 | 242 |
|
244 | 243 | def forward(self, x: MetaTensor) -> MetaTensor: |
245 | 244 | """ |
@@ -282,7 +281,7 @@ def forward(self, x: MetaTensor) -> MetaTensor: |
282 | 281 | image_or_list_of_images = x.cpu().numpy()[0, :] |
283 | 282 |
|
284 | 283 | # input_files should be a list of file paths, one per modality |
285 | | - prediction_output = self.predictor.predict_from_list_of_npy_arrays( |
| 284 | + prediction_output = self.predictor.predict_from_list_of_npy_arrays( # type: ignore |
286 | 285 | image_or_list_of_images, |
287 | 286 | None, |
288 | 287 | properties_or_list_of_properties, |
|
0 commit comments