Commit 8e4a66c

Code reformatting
1 parent: ccdc76c

3 files changed: 74 additions, 76 deletions

monai/bundle/__init__.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -13,7 +13,7 @@
 from .config_item import ComponentLocator, ConfigComponent, ConfigExpression, ConfigItem, Instantiable
 from .config_parser import ConfigParser
-from .nnunet import get_nnunet_monai_predictor, get_nnunet_trainer, nnUNetMONAIModelWrapper
+from .nnunet import ModelnnUNetWrapper, convert_nnunet_to_monai_bundle, get_nnunet_monai_predictor, get_nnunet_trainer
 from .properties import InferProperties, MetaProperties, TrainProperties
 from .reference_resolver import ReferenceResolver
 from .scripts import (
```
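This rename is user-facing: nnUNetMONAIModelWrapper becomes ModelnnUNetWrapper, and convert_nnunet_to_monai_bundle is now re-exported at the package level. A minimal sketch of the updated import surface for downstream code:

```python
# After this commit the wrapper class and the bundle converter are importable
# from the package root; the old name nnUNetMONAIModelWrapper is gone.
from monai.bundle import (
    ModelnnUNetWrapper,
    convert_nnunet_to_monai_bundle,
    get_nnunet_monai_predictor,
    get_nnunet_trainer,
)
```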

monai/bundle/nnunet.py

Lines changed: 34 additions & 31 deletions
```diff
@@ -11,21 +11,21 @@
 from __future__ import annotations

 import os
+import shutil
+from pathlib import Path

 import numpy as np
 import torch
 from torch._dynamo import OptimizedModule
 from torch.backends import cudnn

-from pathlib import Path
-import shutil
 from monai.data.meta_tensor import MetaTensor
 from monai.utils import optional_import

 join, _ = optional_import("batchgenerators.utilities.file_and_folder_operations", name="join")
 load_json, _ = optional_import("batchgenerators.utilities.file_and_folder_operations", name="load_json")

-__all__ = ["get_nnunet_trainer", "get_nnunet_monai_predictor", "nnUNetMONAIModelWrapper"]
+__all__ = ["get_nnunet_trainer", "get_nnunet_monai_predictor", "convert_nnunet_to_monai_bundle", "ModelnnUNetWrapper"]


 def get_nnunet_trainer(
```
```diff
@@ -42,7 +42,7 @@ def get_nnunet_trainer(
     only_run_validation=False,
     disable_checkpointing=False,
     val_with_best=False,
-    device=torch.device("cuda"),
+    device="cuda",
     pretrained_model=None,
 ):
     """
```
```diff
@@ -98,7 +98,7 @@ def get_nnunet_trainer(
         Whether to disable checkpointing. Default is False.
     val_with_best : bool, optional
         Whether to validate with the best model. Default is False.
-    device : torch.device, optional
+    device : str, optional
         The device to be used for training. Default is 'cuda'.
     pretrained_model : str, optional
         Path to the pretrained model file.
```
```diff
@@ -130,7 +130,7 @@ def get_nnunet_trainer(
         trainer_class_name,
         plans_identifier,
         use_compressed_data,
-        device=device,
+        device=torch.device(device),
     )
     if disable_checkpointing:
         nnunet_trainer.disable_checkpointing = disable_checkpointing
```
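Note that the default device is now a plain string that get_nnunet_trainer wraps in torch.device itself, so callers no longer construct the device object. A minimal sketch of the updated call, with the dataset id and trainer settings borrowed from the integration test below:

```python
from monai.bundle.nnunet import get_nnunet_trainer

# `device` is now a string; the function calls torch.device(device) internally.
nnunet_trainer = get_nnunet_trainer(
    dataset_name_or_id="001",
    configuration="3d_fullres",
    fold=0,
    device="cuda",  # previously: device=torch.device("cuda")
)
```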
```diff
@@ -150,7 +150,7 @@ def get_nnunet_trainer(
     return nnunet_trainer


-class nnUNetMONAIModelWrapper(torch.nn.Module):
+class ModelnnUNetWrapper(torch.nn.Module):
     """
     A wrapper class for nnUNet model integration with MONAI framework.
     The wrapper can be use to integrate the nnUNet Bundle within MONAI framework for inference.
```
```diff
@@ -188,7 +188,7 @@ def __init__(self, predictor, model_folder, model_name="model.pt"):

         from nnunetv2.utilities.plans_handling.plans_handler import PlansManager

-        ## Block Added from nnUNet/nnunetv2/inference/predict_from_raw_data.py#nnUNetPredictor
+        # Block Added from nnUNet/nnunetv2/inference/predict_from_raw_data.py#nnUNetPredictor
         dataset_json = load_json(join(model_training_output_dir, "dataset.json"))
         plans = load_json(join(model_training_output_dir, "plans.json"))
         plans_manager = PlansManager(plans)
```
```diff
@@ -253,17 +253,17 @@ def __init__(self, predictor, model_folder, model_name="model.pt"):
         ):
             print("Using torch.compile")
             predictor.network = torch.compile(self.network)
-        ## End Block
+        # End Block
         self.network_weights = self.predictor.network

     def forward(self, x):
         if type(x) is tuple:  # if batch is decollated (list of tensors)
             input_files = [img.meta["filename_or_obj"][0] for img in x]
-        else: # if batch is collated
+        else:  # if batch is collated
             input_files = x.meta["filename_or_obj"]
         if type(input_files) is str:
             input_files = [input_files]
-
+
         # input_files should be a list of file paths, one per modality
         prediction_output = self.predictor.predict_from_files(
             [input_files],
```
```diff
@@ -277,11 +277,11 @@ def forward(self, x):
             part_id=0,
         )
         # prediction_output is a list of numpy arrays, with dimensions (H, W, D), output from ArgMax
-
+
         out_tensors = []
-        for out in prediction_output: # Add batch and channel dimensions
+        for out in prediction_output:  # Add batch and channel dimensions
             out_tensors.append(torch.from_numpy(np.expand_dims(np.expand_dims(out, 0), 0)))
-        out_tensor = torch.cat(out_tensors, 0) # Concatenate along batch dimension
+        out_tensor = torch.cat(out_tensors, 0)  # Concatenate along batch dimension

         if type(x) is tuple:
             return MetaTensor(out_tensor, meta=x[0].meta)
```
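For readers skimming the diff: forward() hands nnUNet a list of file paths, gets back one (H, W, D) argmax array per case, and the loop above rebuilds the (B, C, H, W, D) layout MONAI transforms expect. A standalone sketch of that shape bookkeeping (array values and sizes are illustrative):

```python
import numpy as np
import torch

pred = np.zeros((96, 96, 64))  # one case: (H, W, D) argmax output from nnUNet
expanded = np.expand_dims(np.expand_dims(pred, 0), 0)  # add channel, then batch: (1, 1, 96, 96, 64)
out_tensor = torch.cat([torch.from_numpy(expanded)], 0)  # concatenate cases along dim 0
print(out_tensor.shape)  # torch.Size([1, 1, 96, 96, 64])
```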
```diff
@@ -338,7 +338,7 @@ def get_nnunet_monai_predictor(model_folder, model_name="model.pt"):
         allow_tqdm=True,
     )
     # initializes the network architecture, loads the checkpoint
-    wrapper = nnUNetMONAIModelWrapper(predictor, model_folder, model_name)
+    wrapper = ModelnnUNetWrapper(predictor, model_folder, model_name)
     return wrapper
```
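The factory above is the intended entry point; the renamed wrapper is constructed internally and returned as an ordinary torch.nn.Module. A hedged usage sketch, assuming a bundle previously produced by convert_nnunet_to_monai_bundle (the path is illustrative):

```python
from pathlib import Path

from monai.bundle import get_nnunet_monai_predictor

# The models folder is expected to hold model.pt, nnunet_checkpoint.pth,
# plans.json and dataset.json, as written by convert_nnunet_to_monai_bundle.
predictor = get_nnunet_monai_predictor(Path("bundle_root").joinpath("models"))
# `predictor` is a ModelnnUNetWrapper; calling it on a MetaTensor batch
# produced by LoadImaged/EnsureChannelFirstd runs nnUNet inference.
```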

```diff
@@ -376,29 +376,32 @@ def convert_nnunet_to_monai_bundle(nnunet_config, bundle_root_folder, fold=0):

     from nnunetv2.utilities.dataset_name_id_conversion import maybe_convert_to_dataset_name

-
     dataset_name = maybe_convert_to_dataset_name(nnunet_config["dataset_name_or_id"])
     nnunet_model_folder = Path(os.environ["nnUNet_results"]).joinpath(
-        dataset_name,
-        f"{nnunet_trainer}__{nnunet_plans}__{nnunet_configuration}")
-
-    nnunet_checkpoint_final = torch.load(Path(nnunet_model_folder).joinpath(f"fold_{fold}","checkpoint_final.pth"))
-    nnunet_checkpoint_best = torch.load(Path(nnunet_model_folder).joinpath(f"fold_{fold}","checkpoint_best.pth"))
+        dataset_name, f"{nnunet_trainer}__{nnunet_plans}__{nnunet_configuration}"
+    )
+
+    nnunet_checkpoint_final = torch.load(Path(nnunet_model_folder).joinpath(f"fold_{fold}", "checkpoint_final.pth"))
+    nnunet_checkpoint_best = torch.load(Path(nnunet_model_folder).joinpath(f"fold_{fold}", "checkpoint_best.pth"))

     nnunet_checkpoint = {}
-    nnunet_checkpoint['inference_allowed_mirroring_axes'] = nnunet_checkpoint_final['inference_allowed_mirroring_axes']
-    nnunet_checkpoint['init_args'] = nnunet_checkpoint_final['init_args']
-    nnunet_checkpoint['trainer_name'] = nnunet_checkpoint_final['trainer_name']
+    nnunet_checkpoint["inference_allowed_mirroring_axes"] = nnunet_checkpoint_final["inference_allowed_mirroring_axes"]
+    nnunet_checkpoint["init_args"] = nnunet_checkpoint_final["init_args"]
+    nnunet_checkpoint["trainer_name"] = nnunet_checkpoint_final["trainer_name"]

-    torch.save(nnunet_checkpoint, Path(bundle_root_folder).joinpath("models","nnunet_checkpoint.pth"))
+    torch.save(nnunet_checkpoint, Path(bundle_root_folder).joinpath("models", "nnunet_checkpoint.pth"))

     monai_last_checkpoint = {}
-    monai_last_checkpoint['network_weights'] = nnunet_checkpoint_final['network_weights']
-    torch.save(monai_last_checkpoint, Path(bundle_root_folder).joinpath("models","model.pt"))
+    monai_last_checkpoint["network_weights"] = nnunet_checkpoint_final["network_weights"]
+    torch.save(monai_last_checkpoint, Path(bundle_root_folder).joinpath("models", "model.pt"))

     monai_best_checkpoint = {}
-    monai_best_checkpoint['network_weights'] = nnunet_checkpoint_best['network_weights']
-    torch.save(monai_best_checkpoint, Path(bundle_root_folder).joinpath("models","best_model.pt"))
+    monai_best_checkpoint["network_weights"] = nnunet_checkpoint_best["network_weights"]
+    torch.save(monai_best_checkpoint, Path(bundle_root_folder).joinpath("models", "best_model.pt"))

-    shutil.copy(Path(nnunet_model_folder).joinpath("plans.json"),Path(bundle_root_folder).joinpath("models","plans.json"))
-    shutil.copy(Path(nnunet_model_folder).joinpath("dataset.json"),Path(bundle_root_folder).joinpath("models","dataset.json"))
+    shutil.copy(
+        Path(nnunet_model_folder).joinpath("plans.json"), Path(bundle_root_folder).joinpath("models", "plans.json")
+    )
+    shutil.copy(
+        Path(nnunet_model_folder).joinpath("dataset.json"), Path(bundle_root_folder).joinpath("models", "dataset.json")
+    )
```
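Beyond the quoting and line-wrapping cleanups, the function's contract is unchanged: it reads the trained fold's checkpoints from $nnUNet_results and materializes them as a MONAI bundle. A condensed sketch of the call, mirroring the integration test (the bundle path is illustrative, and nnUNet_results must point at a finished training run):

```python
from pathlib import Path

from monai.bundle import convert_nnunet_to_monai_bundle

nnunet_config = {"dataset_name_or_id": "001", "nnunet_trainer": "nnUNetTrainer_1epoch"}
bundle_root = "bundle_root"
Path(bundle_root).joinpath("models").mkdir(parents=True, exist_ok=True)

# Reads fold_0/checkpoint_final.pth and checkpoint_best.pth from
# $nnUNet_results and writes models/nnunet_checkpoint.pth, models/model.pt,
# models/best_model.pt, models/plans.json and models/dataset.json.
convert_nnunet_to_monai_bundle(nnunet_config, bundle_root, 0)
```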

tests/test_integration_nnunet_bundle.py

Lines changed: 39 additions & 44 deletions
```diff
@@ -14,17 +14,16 @@
 import os
 import tempfile
 import unittest
+from pathlib import Path

 import nibabel as nib
 import numpy as np

 from monai.apps.nnunet import nnUNetV2Runner
-from monai.bundle.nnunet import get_nnunet_trainer, convert_nnunet_to_monai_bundle, get_nnunet_monai_predictor
-from monai.transforms import LoadImaged, SaveImaged, Transposed, EnsureChannelFirstd, Compose, Decollated
-from monai.data import DataLoader, Dataset
-from pathlib import Path
 from monai.bundle.config_parser import ConfigParser
-from monai.data import create_test_image_3d
+from monai.bundle.nnunet import convert_nnunet_to_monai_bundle, get_nnunet_monai_predictor, get_nnunet_trainer
+from monai.data import DataLoader, Dataset, create_test_image_3d
+from monai.transforms import Compose, Decollated, EnsureChannelFirstd, LoadImaged, SaveImaged, Transposed
 from monai.utils import optional_import
 from tests.utils import SkipIfBeforePyTorchVersion, skip_if_downloading_fails, skip_if_no_cuda, skip_if_quick
```
```diff
@@ -86,57 +85,53 @@ def setUp(self) -> None:
         self.test_path = test_path

     @skip_if_no_cuda
-    def test_nnunetBundle_get_trainer(self) -> None:
-        runner = nnUNetV2Runner(input_config=self.data_src_cfg, trainer_class_name="nnUNetTrainer_1epoch")
+    def test_nnunet_bundle(self) -> None:
+        runner = nnUNetV2Runner(input_config=self.data_src_cfg, trainer_class_name="nnUNetTrainer_1epoch",work_dir=self.test_path)
         with skip_if_downloading_fails():
             runner.run(run_train=False, run_find_best_configuration=False, run_predict_ensemble_postprocessing=False)

-        nnunet_trainer = get_nnunet_trainer(dataset_name_or_id=runner.dataset_name, fold=0,configuration="3d_fullres")
-
+        nnunet_trainer = get_nnunet_trainer(
+            dataset_name_or_id=runner.dataset_name, fold=0, configuration="3d_fullres"
+        )
+
         print("Max Epochs: ", nnunet_trainer.num_epochs)
         print("Num Iterations: ", nnunet_trainer.num_iterations_per_epoch)
-        print("Train Batch dims: ", next(nnunet_trainer.dataloader_train.generator)['data'].shape)
-        print("Val Batch dims: ", next(nnunet_trainer.dataloader_val.generator)['data'].shape)
+        print("Train Batch dims: ", next(nnunet_trainer.dataloader_train.generator)["data"].shape)
+        print("Val Batch dims: ", next(nnunet_trainer.dataloader_val.generator)["data"].shape)
         print("Network: ", nnunet_trainer.network)
         print("Optimizer: ", nnunet_trainer.optimizer)
         print("Loss Function: ", nnunet_trainer.loss)
         print("LR Scheduler: ", nnunet_trainer.lr_scheduler)
         print("Device: ", nnunet_trainer.device)
-        runner.train("3d_fullres")
-    @skip_if_no_cuda
-    def test_nnunetBundle_convert_bundle(self) -> None:
-
-
-        nnunet_config = {
-            "dataset_name_or_id": "001",
-            "nnunet_trainer": "nnUNetTrainer_1epoch",
-        }
-        self.bundle_root = os.path.join("bundle_root")
-
-        Path(self.bundle_root).joinpath("models").mkdir(parents=True, exist_ok=True)
-        convert_nnunet_to_monai_bundle(nnunet_config, self.bundle_root, 0)
-
-
-    def test_nnunetBundle_predict_from_bundle(self) -> None:
-        data_transforms = Compose([
-            LoadImaged(keys="image"),
-            EnsureChannelFirstd(keys="image"),
-        ])
-        dataset = Dataset(data=[{"image": os.path.join(self.test_path, "dataroot", "val_001.fake.nii.gz")}],
-                          transform=data_transforms)
-        data_loader = DataLoader(dataset, batch_size=1)
-        input = next(iter(data_loader))
-
-        predictor = get_nnunet_monai_predictor(Path(self.bundle_root).joinpath("models"))
-        pred_batch = predictor(input["image"])
-        Path(self.sim_dataroot).joinpath("predictions").mkdir(parents=True, exist_ok=True)
-
-        post_processing_transforms = Compose([
+        runner.train_single_model("3d_fullres", fold="0")
+
+        nnunet_config = {"dataset_name_or_id": "001", "nnunet_trainer": "nnUNetTrainer_1epoch"}
+        self.bundle_root = os.path.join("bundle_root")
+
+        Path(self.bundle_root).joinpath("models").mkdir(parents=True, exist_ok=True)
+        convert_nnunet_to_monai_bundle(nnunet_config, self.bundle_root, 0)
+
+        data_transforms = Compose([LoadImaged(keys="image"), EnsureChannelFirstd(keys="image")])
+        dataset = Dataset(
+            data=[{"image": os.path.join(self.test_path, "dataroot", "val_001.fake.nii.gz")}], transform=data_transforms
+        )
+        data_loader = DataLoader(dataset, batch_size=1)
+        input = next(iter(data_loader))
+
+        predictor = get_nnunet_monai_predictor(Path(self.bundle_root).joinpath("models"))
+        pred_batch = predictor(input["image"])
+        Path(self.sim_dataroot).joinpath("predictions").mkdir(parents=True, exist_ok=True)
+
+        post_processing_transforms = Compose(
+            [
                 Decollated(keys=None, detach=True),
                 Transposed(keys="pred", indices=[0, 3, 2, 1]),
-            SaveImaged(keys="pred", output_dir=Path(self.sim_dataroot).joinpath("predictions"), output_postfix="pred"),
-        ])
-        post_processing_transforms({"pred": pred_batch})
+                SaveImaged(
+                    keys="pred", output_dir=Path(self.sim_dataroot).joinpath("predictions"), output_postfix="pred"
+                ),
+            ]
+        )
+        post_processing_transforms({"pred": pred_batch})

     def tearDown(self) -> None:
         self.test_dir.cleanup()
```
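The tail of the consolidated test also documents the output contract: the wrapper returns a batched (B, C, H, W, D) MetaTensor, which is decollated and has its spatial axes transposed back before saving. A standalone sketch of that chain on a dummy prediction (shapes are illustrative; the axis order is copied from the test):

```python
import torch

from monai.data.meta_tensor import MetaTensor
from monai.transforms import Compose, Decollated, Transposed

pred_batch = MetaTensor(torch.zeros(1, 1, 96, 96, 64))  # (B, C, H, W, D) dummy prediction
post = Compose([Decollated(keys=None, detach=True), Transposed(keys="pred", indices=[0, 3, 2, 1])])
out = post({"pred": pred_batch})
# `out` is a list with one dict per case; out[0]["pred"] has shape (1, 64, 96, 96)
```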
