Skip to content

Commit 60185d1

Browse files
Add new functions to nnunet_bundle for converting between MONAI and nnU-Net formats
1 parent 1c41164 commit 60185d1

File tree

2 files changed

+31
-32
lines changed

2 files changed

+31
-32
lines changed

docs/source/apps.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -287,3 +287,5 @@ FastMRIReader
287287
.. autofunction:: monai.apps.nnunet.get_nnunet_trainer
288288
.. autofunction:: monai.apps.nnunet.get_nnunet_monai_predictor
289289
.. autofunction:: monai.apps.nnunet.convert_nnunet_to_monai_bundle
290+
.. autofunction:: monai.apps.nnunet.convert_monai_bundle_to_nnunet
291+
.. autofunction:: monai.apps.nnunet.get_network_from_nnunet_plans

monai/apps/nnunet/nnunet_bundle.py

Lines changed: 29 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
import os
1414
import shutil
1515
from pathlib import Path
16+
from typing import Optional, Union
1617

1718
import numpy as np
1819
import torch
@@ -21,11 +22,17 @@
2122
from monai.data.meta_tensor import MetaTensor
2223
from monai.utils import optional_import
2324

24-
from typing import Union, Optional
2525
join, _ = optional_import("batchgenerators.utilities.file_and_folder_operations", name="join")
2626
load_json, _ = optional_import("batchgenerators.utilities.file_and_folder_operations", name="load_json")
2727

28-
__all__ = ["get_nnunet_trainer", "get_nnunet_monai_predictor", "convert_nnunet_to_monai_bundle", "convert_monai_bundle_to_nnunet","ModelnnUNetWrapper"]
28+
__all__ = [
29+
"get_nnunet_trainer",
30+
"get_nnunet_monai_predictor",
31+
"get_network_from_nnunet_plans",
32+
"convert_nnunet_to_monai_bundle",
33+
"convert_monai_bundle_to_nnunet",
34+
"ModelnnUNetWrapper",
35+
]
2936

3037

3138
def get_nnunet_trainer(
@@ -107,7 +114,6 @@ def get_nnunet_trainer(
107114
)
108115
raise e
109116

110-
111117
from nnunetv2.run.run_training import get_trainer_from_args, maybe_load_checkpoint
112118

113119
nnunet_trainer = get_trainer_from_args(
@@ -178,7 +184,7 @@ def __init__(self, predictor: object, model_folder: str, model_name: str = "mode
178184
plans_manager = PlansManager(plans)
179185

180186
parameters = []
181-
187+
182188
checkpoint = torch.load(
183189
join(Path(model_training_output_dir).parent, "nnunet_checkpoint.pth"), map_location=torch.device("cpu")
184190
)
@@ -190,9 +196,7 @@ def __init__(self, predictor: object, model_folder: str, model_name: str = "mode
190196
else None
191197
)
192198
if Path(model_training_output_dir).joinpath(model_name).is_file():
193-
monai_checkpoint = torch.load(
194-
join(model_training_output_dir, model_name), map_location=torch.device("cpu")
195-
)
199+
monai_checkpoint = torch.load(join(model_training_output_dir, model_name), map_location=torch.device("cpu"))
196200
if "network_weights" in monai_checkpoint.keys():
197201
parameters.append(monai_checkpoint["network_weights"])
198202
else:
@@ -230,10 +234,7 @@ def __init__(self, predictor: object, model_folder: str, model_name: str = "mode
230234
predictor.trainer_name = trainer_name
231235
predictor.allowed_mirroring_axes = inference_allowed_mirroring_axes
232236
predictor.label_manager = plans_manager.get_label_manager(dataset_json)
233-
if (
234-
("nnUNet_compile" in os.environ.keys())
235-
and (os.environ["nnUNet_compile"].lower() in ("true", "1", "t"))
236-
):
237+
if ("nnUNet_compile" in os.environ.keys()) and (os.environ["nnUNet_compile"].lower() in ("true", "1", "t")):
237238
print("Using torch.compile")
238239
# End Block
239240
self.network_weights = self.predictor.network
@@ -265,7 +266,11 @@ def forward(self, x: MetaTensor) -> MetaTensor:
265266
if "pixdim" in x.meta:
266267
properties_or_list_of_properties = {"spacing": x.meta["pixdim"][0][1:4].numpy().tolist()}
267268
elif "affine" in x.meta:
268-
spacing = [abs(x.meta['affine'][0][0].item()), abs(x.meta['affine'][1][1].item()), abs(x.meta['affine'][2][2].item())]
269+
spacing = [
270+
abs(x.meta["affine"][0][0].item()),
271+
abs(x.meta["affine"][1][1].item()),
272+
abs(x.meta["affine"][2][2].item()),
273+
]
269274
properties_or_list_of_properties = {"spacing": spacing}
270275
else:
271276
properties_or_list_of_properties = {"spacing": [1.0, 1.0, 1.0]}
@@ -348,9 +353,7 @@ def get_nnunet_monai_predictor(model_folder: str, model_name: str = "model.pt")
348353
return wrapper
349354

350355

351-
def convert_nnunet_to_monai_bundle(
352-
nnunet_config: dict, bundle_root_folder: str, fold: int = 0
353-
) -> None:
356+
def convert_nnunet_to_monai_bundle(nnunet_config: dict, bundle_root_folder: str, fold: int = 0) -> None:
354357
"""
355358
Convert nnUNet model checkpoints and configuration to MONAI bundle format.
356359
@@ -421,14 +424,14 @@ def convert_nnunet_to_monai_bundle(
421424

422425

423426
def get_network_from_nnunet_plans(
424-
plans_file: str,
425-
dataset_file: str,
426-
configuration: str,
427-
model_ckpt: Optional[str] = None,
428-
model_key_in_ckpt: str = "model"
427+
plans_file: str,
428+
dataset_file: str,
429+
configuration: str,
430+
model_ckpt: Optional[str] = None,
431+
model_key_in_ckpt: str = "model",
429432
) -> torch.nn.Module:
430433
"""
431-
Load and initialize a neural network based on nnUNet plans and configuration.
434+
Load and initialize a nnUNet network based on nnUNet plans and configuration.
432435
433436
Parameters
434437
----------
@@ -481,11 +484,7 @@ def get_network_from_nnunet_plans(
481484
return network
482485

483486

484-
def convert_monai_bundle_to_nnunet(
485-
nnunet_config: dict,
486-
bundle_root_folder: str,
487-
fold: int = 0
488-
) -> None:
487+
def convert_monai_bundle_to_nnunet(nnunet_config: dict, bundle_root_folder: str, fold: int = 0) -> None:
489488
"""
490489
Convert a MONAI bundle to nnU-Net format.
491490
@@ -520,11 +519,7 @@ def convert_monai_bundle_to_nnunet(
520519
from nnunetv2.utilities.dataset_name_id_conversion import maybe_convert_to_dataset_name
521520

522521
def subfiles(
523-
folder: str,
524-
join: bool = True,
525-
prefix: Optional[str] = None,
526-
suffix: Optional[str] = None,
527-
sort: bool = True
522+
folder: str, join: bool = True, prefix: Optional[str] = None, suffix: Optional[str] = None, sort: bool = True
528523
) -> list[str]:
529524
if join:
530525
l = os.path.join # noqa: E741
@@ -562,7 +557,9 @@ def subfiles(
562557

563558
epochs.sort()
564559
final_epoch: int = epochs[-1]
565-
monai_last_checkpoint: dict = torch.load(f"{bundle_root_folder}/models/fold_{fold}/checkpoint_epoch={final_epoch}.pt")
560+
monai_last_checkpoint: dict = torch.load(
561+
f"{bundle_root_folder}/models/fold_{fold}/checkpoint_epoch={final_epoch}.pt"
562+
)
566563

567564
best_checkpoints: list[str] = subfiles(
568565
Path(bundle_root_folder).joinpath("models", f"fold_{fold}"),

0 commit comments

Comments
 (0)