Skip to content

Commit cfd5636

Browse files
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent 151177d commit cfd5636

File tree

8 files changed: +149 additions, -143 deletions

generation/maisi/maisi_inference_tutorial.ipynb

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -210,7 +210,7 @@
210210
" \"path\": \"models/autoencoder_epoch273.pt\",\n",
211211
" \"url\": \"https://developer.download.nvidia.com/assets/Clara/monai/tutorials\"\n",
212212
" \"/model_zoo/model_maisi_autoencoder_epoch273_alternative.pt\",\n",
213-
" }, \n",
213+
" },\n",
214214
" {\n",
215215
" \"path\": \"models/mask_generation_autoencoder.pt\",\n",
216216
" \"url\": \"https://developer.download.nvidia.com/assets/Clara/monai\" \"/tutorials/mask_generation_autoencoder.pt\",\n",
@@ -219,12 +219,11 @@
219219
" \"path\": \"models/mask_generation_diffusion_unet.pt\",\n",
220220
" \"url\": \"https://developer.download.nvidia.com/assets/Clara/monai\"\n",
221221
" \"/tutorials/model_zoo/model_maisi_mask_generation_diffusion_unet_v2.pt\",\n",
222-
" }, \n",
222+
" },\n",
223223
" {\n",
224224
" \"path\": \"configs/all_anatomy_size_condtions.json\",\n",
225225
" \"url\": \"https://developer.download.nvidia.com/assets/Clara/monai/tutorials/all_anatomy_size_condtions.json\",\n",
226226
" },\n",
227-
" \n",
228227
"]\n",
229228
"\n",
230229
"if maisi_version == \"maisi3d-ddpm\":\n",
@@ -259,15 +258,14 @@
259258
" },\n",
260259
" {\n",
261260
" \"path\": \"models/controlnet_3d_rflow.pt\",\n",
262-
" \"url\": \"https://developer.download.nvidia.com/assets/Clara/monai/tutorials/\"\n",
263-
" \"controlnet_rflow_epoch208.pt\",\n",
261+
" \"url\": \"https://developer.download.nvidia.com/assets/Clara/monai/tutorials/\" \"controlnet_rflow_epoch208.pt\",\n",
264262
" },\n",
265263
" {\n",
266264
" \"path\": \"configs/candidate_masks_flexible_size_and_spacing_4000.json\",\n",
267265
" \"url\": \"https://developer.download.nvidia.com/assets/Clara/monai\"\n",
268266
" \"/tutorials/candidate_masks_flexible_size_and_spacing_4000.json\",\n",
269267
" },\n",
270-
" {\n",
268+
" {\n",
271269
" \"path\": \"datasets/all_masks_flexible_size_and_spacing_4000.zip\",\n",
272270
" \"url\": \"https://developer.download.nvidia.com/assets/Clara/monai\"\n",
273271
" \"/tutorials/all_masks_flexible_size_and_spacing_4000.zip\",\n",
@@ -545,7 +543,7 @@
545543
" mask_generation_num_inference_steps=args.mask_generation_num_inference_steps,\n",
546544
" random_seed=args.random_seed,\n",
547545
" autoencoder_sliding_window_infer_size=args.autoencoder_sliding_window_infer_size,\n",
548-
" autoencoder_sliding_window_infer_overlap=args.autoencoder_sliding_window_infer_overlap\n",
546+
" autoencoder_sliding_window_infer_overlap=args.autoencoder_sliding_window_infer_overlap,\n",
549547
")"
550548
]
551549
},

generation/maisi/scripts/augmentation.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -370,4 +370,4 @@ def augmentation(pt_nda, output_size, random_seed):
370370
print("augmenting body")
371371
pt_nda = augmentation_body(pt_nda, random_seed)
372372

373-
return pt_nda
373+
return pt_nda

generation/maisi/scripts/diff_model_infer.py

Lines changed: 17 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -136,8 +136,8 @@ def run_inference(
136136
np.ndarray: Generated synthetic image data.
137137
"""
138138
include_body_region = unet.include_top_region_index_input
139-
include_modality = (unet.num_class_embeds is not None)
140-
139+
include_modality = unet.num_class_embeds is not None
140+
141141
noise = torch.randn(
142142
(
143143
1,
@@ -178,18 +178,22 @@ def run_inference(
178178
"timesteps": torch.Tensor((t,)).to(device),
179179
"spacing_tensor": spacing_tensor,
180180
}
181-
181+
182182
# Add extra arguments if include_body_region is True
183183
if include_body_region:
184-
unet_inputs.update({
185-
"top_region_index_tensor": top_region_index_tensor,
186-
"bottom_region_index_tensor": bottom_region_index_tensor
187-
})
184+
unet_inputs.update(
185+
{
186+
"top_region_index_tensor": top_region_index_tensor,
187+
"bottom_region_index_tensor": bottom_region_index_tensor,
188+
}
189+
)
188190

189191
if include_modality:
190-
unet_inputs.update({
191-
"class_labels": modality_tensor,
192-
})
192+
unet_inputs.update(
193+
{
194+
"class_labels": modality_tensor,
195+
}
196+
)
193197
model_output = unet(**unet_inputs)
194198
if not isinstance(noise_scheduler, RFlowScheduler):
195199
image, _ = noise_scheduler.step(model_output, t, image) # type: ignore
@@ -241,9 +245,7 @@ def save_image(
241245

242246

243247
@torch.inference_mode()
244-
def diff_model_infer(
245-
env_config_path: str, model_config_path: str, model_def_path: str, num_gpus: int
246-
) -> None:
248+
def diff_model_infer(env_config_path: str, model_config_path: str, model_def_path: str, num_gpus: int) -> None:
247249
"""
248250
Main function to run the diffusion model inference.
249251
@@ -301,7 +303,7 @@ def diff_model_infer(
301303
modality_tensor,
302304
output_size,
303305
divisor,
304-
logger
306+
logger,
305307
)
306308

307309
timestamp = datetime.now().strftime("%Y%m%d%H%M%S")
@@ -352,6 +354,4 @@ def diff_model_infer(
352354
)
353355

354356
args = parser.parse_args()
355-
diff_model_infer(
356-
args.env_config, args.model_config, args.model_def, args.num_gpus
357-
)
357+
diff_model_infer(args.env_config, args.model_config, args.model_def, args.num_gpus)

generation/maisi/scripts/diff_model_train.py

Lines changed: 19 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -226,8 +226,8 @@ def train_one_epoch(
226226
torch.Tensor: Training loss for the epoch.
227227
"""
228228
include_body_region = unet.include_top_region_index_input
229-
include_modality = (unet.num_class_embeds is not None)
230-
229+
include_modality = unet.num_class_embeds is not None
230+
231231
if local_rank == 0:
232232
current_lr = optimizer.param_groups[0]["lr"]
233233
logger.info(f"Epoch {epoch + 1}, lr {current_lr}.")
@@ -248,7 +248,7 @@ def train_one_epoch(
248248
bottom_region_index_tensor = train_data["bottom_region_index"].to(device)
249249
# We trained with only CT in this version
250250
if include_modality:
251-
modality_tensor = torch.ones((len(images),),dtype=torch.long).to(device)
251+
modality_tensor = torch.ones((len(images),), dtype=torch.long).to(device)
252252
spacing_tensor = train_data["spacing"].to(device)
253253

254254
optimizer.zero_grad(set_to_none=True)
@@ -268,18 +268,22 @@ def train_one_epoch(
268268
"x": noisy_latent,
269269
"timesteps": timesteps,
270270
"spacing_tensor": spacing_tensor,
271-
}
271+
}
272272
# Add extra arguments if include_body_region is True
273273
if include_body_region:
274-
unet_inputs.update({
275-
"top_region_index_tensor": top_region_index_tensor,
276-
"bottom_region_index_tensor": bottom_region_index_tensor
277-
})
274+
unet_inputs.update(
275+
{
276+
"top_region_index_tensor": top_region_index_tensor,
277+
"bottom_region_index_tensor": bottom_region_index_tensor,
278+
}
279+
)
278280
if include_modality:
279-
unet_inputs.update({
280-
"class_labels": modality_tensor,
281-
})
282-
model_output = unet(**unet_inputs)
281+
unet_inputs.update(
282+
{
283+
"class_labels": modality_tensor,
284+
}
285+
)
286+
model_output = unet(**unet_inputs)
283287

284288
if noise_scheduler.prediction_type == DDPMPredictionType.EPSILON:
285289
# predict noise
@@ -359,11 +363,7 @@ def save_checkpoint(
359363

360364

361365
def diff_model_train(
362-
env_config_path: str,
363-
model_config_path: str,
364-
model_def_path: str,
365-
num_gpus: int,
366-
amp: bool = True
366+
env_config_path: str, model_config_path: str, model_def_path: str, num_gpus: int, amp: bool = True
367367
) -> None:
368368
"""
369369
Main function to train a diffusion model.
@@ -424,8 +424,6 @@ def diff_model_train(
424424
include_body_region=include_body_region,
425425
)
426426

427-
428-
429427
scale_factor = calculate_scale_factor(train_loader, device, logger)
430428
optimizer = create_optimizer(unet, args.diffusion_unet_train["lr"])
431429

@@ -455,7 +453,7 @@ def diff_model_train(
455453
device,
456454
logger,
457455
local_rank,
458-
amp=amp
456+
amp=amp,
459457
)
460458

461459
loss_torch = loss_torch.tolist()
@@ -498,6 +496,4 @@ def diff_model_train(
498496
parser.add_argument("--no_amp", dest="amp", action="store_false", help="Disable automatic mixed precision training")
499497

500498
args = parser.parse_args()
501-
diff_model_train(
502-
args.env_config, args.model_config, args.model_def, args.num_gpus, args.amp
503-
)
499+
diff_model_train(args.env_config, args.model_config, args.model_def, args.num_gpus, args.amp)

generation/maisi/scripts/infer_controlnet.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -111,8 +111,8 @@ def main():
111111
# define diffusion Model
112112
unet = define_instance(args, "diffusion_unet_def").to(device)
113113
include_body_region = unet.include_top_region_index_input
114-
include_modality = (unet.num_class_embeds is not None)
115-
114+
include_modality = unet.num_class_embeds is not None
115+
116116
# load trained diffusion model
117117
if args.trained_diffusion_path is not None:
118118
if not os.path.exists(args.trained_diffusion_path):

generation/maisi/scripts/quality_check.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,7 @@ def is_outlier(statistics, image_data, label_data, label_int_dict):
113113
high_thresh = max(stats["sigma_6_high"], stats["percentile_99_5"]) # or "sigma_12_high" depending on your needs
114114

115115
if label_name == "bone":
116-
high_thresh = 1000.
116+
high_thresh = 1000.0
117117

118118
# Retrieve the corresponding label integers
119119
labels = label_int_dict.get(label_name, [])
@@ -146,4 +146,4 @@ def is_outlier(statistics, image_data, label_data, label_int_dict):
146146
"high_thresh": high_thresh,
147147
}
148148

149-
return outlier_results
149+
return outlier_results

0 commit comments

Comments (0)