Skip to content

Commit 76d85e1

Browse files
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent 56da1e4 commit 76d85e1

File tree

2 files changed

+12
-8
lines changed

2 files changed

+12
-8
lines changed

generation/maisi/scripts/inference.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ def main():
6363
)
6464
parser.add_argument(
6565
"--version",
66-
default='maisi3d-rflow',
66+
default="maisi3d-rflow",
6767
type=str,
6868
help="maisi_version, choose from ['maisi3d-ddpm', 'maisi3d-rflow']",
6969
)
@@ -97,7 +97,8 @@ def main():
9797
},
9898
{
9999
"path": "models/mask_generation_autoencoder.pt",
100-
"url": "https://developer.download.nvidia.com/assets/Clara/monai" "/tutorials/mask_generation_autoencoder.pt",
100+
"url": "https://developer.download.nvidia.com/assets/Clara/monai"
101+
"/tutorials/mask_generation_autoencoder.pt",
101102
},
102103
{
103104
"path": "models/mask_generation_diffusion_unet.pt",
@@ -114,7 +115,7 @@ def main():
114115
"/tutorials/all_masks_flexible_size_and_spacing_4000.zip",
115116
},
116117
]
117-
118+
118119
if maisi_version == "maisi3d-ddpm":
119120
files += [
120121
{
@@ -142,7 +143,8 @@ def main():
142143
},
143144
{
144145
"path": "models/controlnet_3d_rflow.pt",
145-
"url": "https://developer.download.nvidia.com/assets/Clara/monai/tutorials/" "controlnet_rflow_epoch208.pt",
146+
"url": "https://developer.download.nvidia.com/assets/Clara/monai/tutorials/"
147+
"controlnet_rflow_epoch208.pt",
146148
},
147149
{
148150
"path": "configs/candidate_masks_flexible_size_and_spacing_4000.json",
@@ -151,8 +153,10 @@ def main():
151153
},
152154
]
153155
else:
154-
raise ValueError(f"maisi_version has to be chosen from ['maisi3d-ddpm', 'maisi3d-rflow'], yet got {maisi_version}.")
155-
156+
raise ValueError(
157+
f"maisi_version has to be chosen from ['maisi3d-ddpm', 'maisi3d-rflow'], yet got {maisi_version}."
158+
)
159+
156160
for file in files:
157161
file["path"] = file["path"] if "datasets/" not in file["path"] else os.path.join(root_dir, file["path"])
158162
download_url(url=file["url"], filepath=file["path"])

generation/maisi/scripts/sample.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -269,7 +269,7 @@ def ldm_conditional_sample_one_image(
269269
noise_scheduler.set_timesteps(num_inference_steps=num_inference_steps)
270270

271271
if isinstance(noise_scheduler, DDPMScheduler) and num_inference_steps < noise_scheduler.num_train_timesteps:
272-
warnings.warn(
272+
warnings.warn(
273273
"**************************************************************\n"
274274
"* WARNING: Image noise_scheduler is a DDPMScheduler.\n"
275275
"* We expect num_inference_steps = noise_scheduler.num_train_timesteps"
@@ -829,7 +829,7 @@ def select_mask(self, candidate_mask_files, num_img):
829829
selected_mask_files = []
830830
random.shuffle(candidate_mask_files)
831831

832-
for n in range(num_img*self.max_try_time):
832+
for n in range(num_img * self.max_try_time):
833833
mask_file = candidate_mask_files[n % len(candidate_mask_files)]
834834
selected_mask_files.append({"mask_file": mask_file, "if_aug": True})
835835
return selected_mask_files

0 commit comments

Comments
 (0)