Skip to content

Commit c04cc05

Browse files
Merge branch 'dev' into 8328-nnunet-bundle-integration
2 parents ed2360c + 47798af commit c04cc05

File tree

8 files changed

+2648
-5
lines changed

8 files changed

+2648
-5
lines changed

monai/apps/nnunet/nnunet_bundle.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,6 @@ def get_nnunet_trainer(
124124
fold,
125125
trainer_class_name,
126126
plans_identifier,
127-
use_compressed_data,
128127
device=torch.device(device),
129128
)
130129
if disable_checkpointing:
@@ -235,9 +234,7 @@ def __init__(self, predictor: object, model_folder: Union[str, Path], model_name
235234
predictor.trainer_name = trainer_name # type: ignore
236235
predictor.allowed_mirroring_axes = inference_allowed_mirroring_axes # type: ignore
237236
predictor.label_manager = plans_manager.get_label_manager(dataset_json) # type: ignore
238-
if ("nnUNet_compile" in os.environ.keys()) and (os.environ["nnUNet_compile"].lower() in ("true", "1", "t")):
239-
print("Using torch.compile")
240-
# End Block
237+
241238
self.network_weights = self.predictor.network # type: ignore
242239

243240
def forward(self, x: MetaTensor) -> MetaTensor:

monai/nvflare/__init__.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
# Copyright (c) MONAI Consortium
2+
# Licensed under the Apache License, Version 2.0 (the "License");
3+
# you may not use this file except in compliance with the License.
4+
# You may obtain a copy of the License at
5+
# http://www.apache.org/licenses/LICENSE-2.0
6+
# Unless required by applicable law or agreed to in writing, software
7+
# distributed under the License is distributed on an "AS IS" BASIS,
8+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9+
# See the License for the specific language governing permissions and
10+
# limitations under the License.

monai/nvflare/json_generator.py

Lines changed: 179 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,179 @@
1+
# Copyright (c) MONAI Consortium
2+
# Licensed under the Apache License, Version 2.0 (the "License");
3+
# you may not use this file except in compliance with the License.
4+
# You may obtain a copy of the License at
5+
# http://www.apache.org/licenses/LICENSE-2.0
6+
# Unless required by applicable law or agreed to in writing, software
7+
# distributed under the License is distributed on an "AS IS" BASIS,
8+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9+
# See the License for the specific language governing permissions and
10+
# limitations under the License.
11+
from __future__ import annotations
12+
13+
import json
14+
import os.path
15+
16+
from nvflare.apis.event_type import EventType
17+
from nvflare.apis.fl_context import FLContext
18+
from nvflare.widgets.widget import Widget
19+
20+
21+
class PrepareJsonGenerator(Widget):
22+
"""
23+
A widget class to prepare and generate a JSON file containing data preparation configurations.
24+
25+
Parameters
26+
----------
27+
results_dir : str, optional
28+
The directory where the results will be stored (default is "prepare").
29+
json_file_name : str, optional
30+
The name of the JSON file to be generated (default is "data_dict.json").
31+
32+
Methods
33+
-------
34+
handle_event(event_type: str, fl_ctx: FLContext)
35+
Handles events during the federated learning process. Clears the data preparation configuration
36+
at the start of a run and saves the configuration to a JSON file at the end of a run.
37+
"""
38+
39+
def __init__(self, results_dir="prepare", json_file_name="data_dict.json"):
40+
super(PrepareJsonGenerator, self).__init__()
41+
42+
self._results_dir = results_dir
43+
self._data_prepare_config = {}
44+
self._json_file_name = json_file_name
45+
46+
def handle_event(self, event_type: str, fl_ctx: FLContext):
47+
if event_type == EventType.START_RUN:
48+
self._data_prepare_config.clear()
49+
elif event_type == EventType.END_RUN:
50+
self._data_prepare_config = fl_ctx.get_prop("client_data_dict", None)
51+
run_dir = fl_ctx.get_engine().get_workspace().get_run_dir(fl_ctx.get_job_id())
52+
data_prepare_res_dir = os.path.join(run_dir, self._results_dir)
53+
if not os.path.exists(data_prepare_res_dir):
54+
os.makedirs(data_prepare_res_dir)
55+
56+
res_file_path = os.path.join(data_prepare_res_dir, self._json_file_name)
57+
with open(res_file_path, "w") as f:
58+
json.dump(self._data_prepare_config, f)
59+
60+
61+
class nnUNetPackageReportJsonGenerator(Widget):
62+
"""
63+
A class to generate JSON reports for nnUNet package.
64+
65+
Parameters
66+
----------
67+
results_dir : str, optional
68+
Directory where the report will be saved (default is "package_report").
69+
json_file_name : str, optional
70+
Name of the JSON file to save the report (default is "package_report.json").
71+
72+
Methods
73+
-------
74+
handle_event(event_type: str, fl_ctx: FLContext)
75+
Handles events to clear the report at the start of a run and save the report at the end of a run.
76+
"""
77+
78+
def __init__(self, results_dir="package_report", json_file_name="package_report.json"):
79+
super(nnUNetPackageReportJsonGenerator, self).__init__()
80+
81+
self._results_dir = results_dir
82+
self._report = {}
83+
self._json_file_name = json_file_name
84+
85+
def handle_event(self, event_type: str, fl_ctx: FLContext):
86+
if event_type == EventType.START_RUN:
87+
self._report.clear()
88+
elif event_type == EventType.END_RUN:
89+
datasets = fl_ctx.get_prop("package_report", None)
90+
run_dir = fl_ctx.get_engine().get_workspace().get_run_dir(fl_ctx.get_job_id())
91+
cross_val_res_dir = os.path.join(run_dir, self._results_dir)
92+
if not os.path.exists(cross_val_res_dir):
93+
os.makedirs(cross_val_res_dir)
94+
95+
res_file_path = os.path.join(cross_val_res_dir, self._json_file_name)
96+
with open(res_file_path, "w") as f:
97+
json.dump(datasets, f)
98+
99+
100+
class nnUNetPlansJsonGenerator(Widget):
101+
"""
102+
A class to generate JSON files for nnUNet plans.
103+
104+
Parameters
105+
----------
106+
results_dir : str, optional
107+
Directory where the preprocessing results will be stored (default is "nnUNet_preprocessing").
108+
json_file_name : str, optional
109+
Name of the JSON file to be generated (default is "nnUNetPlans.json").
110+
111+
Methods
112+
-------
113+
handle_event(event_type: str, fl_ctx: FLContext)
114+
Handles events during the federated learning process. Clears the nnUNet plans at the start of a run and saves
115+
the plans to a JSON file at the end of a run.
116+
"""
117+
118+
def __init__(self, results_dir="nnUNet_preprocessing", json_file_name="nnUNetPlans.json"):
119+
120+
super(nnUNetPlansJsonGenerator, self).__init__()
121+
122+
self._results_dir = results_dir
123+
self._nnUNetPlans = {}
124+
self._json_file_name = json_file_name
125+
126+
def handle_event(self, event_type: str, fl_ctx: FLContext):
127+
if event_type == EventType.START_RUN:
128+
self._nnUNetPlans.clear()
129+
elif event_type == EventType.END_RUN:
130+
datasets = fl_ctx.get_prop("nnunet_plans", None)
131+
run_dir = fl_ctx.get_engine().get_workspace().get_run_dir(fl_ctx.get_job_id())
132+
cross_val_res_dir = os.path.join(run_dir, self._results_dir)
133+
if not os.path.exists(cross_val_res_dir):
134+
os.makedirs(cross_val_res_dir)
135+
136+
res_file_path = os.path.join(cross_val_res_dir, self._json_file_name)
137+
with open(res_file_path, "w") as f:
138+
json.dump(datasets, f)
139+
140+
141+
class nnUNetValSummaryJsonGenerator(Widget):
142+
"""
143+
A widget to generate a JSON summary for nnUNet validation results.
144+
145+
Parameters
146+
----------
147+
results_dir : str, optional
148+
Directory where the nnUNet training results are stored (default is "nnUNet_train").
149+
json_file_name : str, optional
150+
Name of the JSON file to save the validation summary (default is "val_summary.json").
151+
152+
Methods
153+
-------
154+
handle_event(event_type: str, fl_ctx: FLContext)
155+
Handles events during the federated learning process. Clears the nnUNet plans at the start of a run and saves
156+
the validation summary to a JSON file at the end of a run.
157+
"""
158+
159+
def __init__(self, results_dir="nnUNet_train", json_file_name="val_summary.json"):
160+
161+
super(nnUNetValSummaryJsonGenerator, self).__init__()
162+
163+
self._results_dir = results_dir
164+
self._nnUNetPlans = {}
165+
self._json_file_name = json_file_name
166+
167+
def handle_event(self, event_type: str, fl_ctx: FLContext):
168+
if event_type == EventType.START_RUN:
169+
self._nnUNetPlans.clear()
170+
elif event_type == EventType.END_RUN:
171+
datasets = fl_ctx.get_prop("val_summary_dict", None)
172+
run_dir = fl_ctx.get_engine().get_workspace().get_run_dir(fl_ctx.get_job_id())
173+
cross_val_res_dir = os.path.join(run_dir, self._results_dir)
174+
if not os.path.exists(cross_val_res_dir):
175+
os.makedirs(cross_val_res_dir)
176+
177+
res_file_path = os.path.join(cross_val_res_dir, self._json_file_name)
178+
with open(res_file_path, "w") as f:
179+
json.dump(datasets, f)

0 commit comments

Comments
 (0)