Skip to content

Commit 9b0a18f

Browse files
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent 45ff2e0 commit 9b0a18f

File tree

4 files changed

+34
-23
lines changed

4 files changed

+34
-23
lines changed

generation/maisi/README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# 🚨🚨🚨 THIS FOLDER IS DEPRECATED 🚨🚨🚨
1+
# 🚨🚨🚨 THIS FOLDER IS DEPRECATED 🚨🚨🚨
22
# 👉 Please switch to: [https://github.com/NVIDIA-Medtech/NV-Generate-CTMR/tree/main](https://github.com/NVIDIA-Medtech/NV-Generate-CTMR/tree/main)
33

44
# Medical AI for Synthetic Imaging (MAISI)

generation/maisi/maisi_inference_tutorial.ipynb

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -205,7 +205,7 @@
205205
" os.makedirs(directory, exist_ok=True)\n",
206206
"root_dir = tempfile.mkdtemp() if directory is None else directory\n",
207207
"\n",
208-
"download_model_data(maisi_version,root_dir)\n",
208+
"download_model_data(maisi_version, root_dir)\n",
209209
"\n",
210210
"for file in files:\n",
211211
" file[\"path\"] = file[\"path\"] if \"datasets/\" not in file[\"path\"] else os.path.join(root_dir, file[\"path\"])\n",

generation/maisi/scripts/download_model_data.py

Lines changed: 31 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,13 @@
66
from huggingface_hub import snapshot_download
77
from typing import List, Dict, Optional
88

9+
910
def fetch_to_hf_path_cmd(
1011
items: List[Dict[str, str]],
11-
root_dir: str = "./", # staging dir for CLI output
12+
root_dir: str = "./", # staging dir for CLI output
1213
revision: str = "main",
1314
overwrite: bool = False,
14-
token: Optional[str] = None, # or rely on env HUGGINGFACE_HUB_TOKEN
15+
token: Optional[str] = None, # or rely on env HUGGINGFACE_HUB_TOKEN
1516
) -> list[str]:
1617
"""
1718
items: list of {"repo_id": "...", "filename": "path/in/repo.ext", "path": "local/target.ext"}
@@ -25,11 +26,11 @@ def fetch_to_hf_path_cmd(
2526
env = os.environ.copy()
2627
if token:
2728
env["HUGGINGFACE_HUB_TOKEN"] = token
28-
env.setdefault("HF_HUB_ENABLE_HF_TRANSFER", "0") # safer in Jupyter
29-
env.setdefault("HF_HUB_DISABLE_PROGRESS_BARS", "0") # show CLI progress in terminal
29+
env.setdefault("HF_HUB_ENABLE_HF_TRANSFER", "0") # safer in Jupyter
30+
env.setdefault("HF_HUB_DISABLE_PROGRESS_BARS", "0") # show CLI progress in terminal
3031

3132
for it in items:
32-
repo_id = it["repo_id"]
33+
repo_id = it["repo_id"]
3334
repo_file = it["filename"]
3435
dst = Path(it["path"])
3536
dst.parent.mkdir(parents=True, exist_ok=True)
@@ -40,11 +41,15 @@ def fetch_to_hf_path_cmd(
4041

4142
# Build command (no shell=True; no quoting issues)
4243
cmd = [
43-
"huggingface-cli", "download",
44+
"huggingface-cli",
45+
"download",
4446
repo_id,
45-
"--include", repo_file,
46-
"--revision", revision,
47-
"--local-dir", str(root),
47+
"--include",
48+
repo_file,
49+
"--revision",
50+
revision,
51+
"--local-dir",
52+
str(root),
4853
]
4954
# Run
5055
subprocess.run(cmd, check=True, env=env)
@@ -69,15 +74,14 @@ def fetch_to_hf_path_cmd(
6974
return saved
7075

7176

72-
73-
def download_model_data(generate_version,root_dir, model_only=False):
77+
def download_model_data(generate_version, root_dir, model_only=False):
7478
# TODO: remove the `files` after the files are uploaded to the NGC
7579
if generate_version == "ddpm-ct" or generate_version == "rflow-ct":
7680
files = [
7781
{
7882
"path": "models/autoencoder_v1.pt",
7983
"repo_id": "nvidia/NV-Generate-CT",
80-
"filename":"models/autoencoder_v1.pt",
84+
"filename": "models/autoencoder_v1.pt",
8185
},
8286
{
8387
"path": "models/mask_generation_autoencoder.pt",
@@ -88,7 +92,8 @@ def download_model_data(generate_version,root_dir, model_only=False):
8892
"path": "models/mask_generation_diffusion_unet.pt",
8993
"repo_id": "nvidia/NV-Generate-CT",
9094
"filename": "models/mask_generation_diffusion_unet.pt",
91-
}]
95+
},
96+
]
9297
if not model_only:
9398
files += [
9499
{
@@ -113,10 +118,12 @@ def download_model_data(generate_version,root_dir, model_only=False):
113118
"path": "models/diff_unet_3d_rflow-mr.pt",
114119
"repo_id": "nvidia/NV-Generate-MR",
115120
"filename": "models/diff_unet_3d_rflow-mr.pt",
116-
}
121+
},
117122
]
118123
else:
119-
raise ValueError(f"generate_version has to be chosen from ['ddpm-ct', 'rflow-ct', 'rflow-mr'], yet got {generate_version}.")
124+
raise ValueError(
125+
f"generate_version has to be chosen from ['ddpm-ct', 'rflow-ct', 'rflow-mr'], yet got {generate_version}."
126+
)
120127
if generate_version == "ddpm-ct":
121128
files += [
122129
{
@@ -128,7 +135,8 @@ def download_model_data(generate_version,root_dir, model_only=False):
128135
"path": "models/controlnet_3d_ddpm-ct.pt",
129136
"repo_id": "nvidia/NV-Generate-CT",
130137
"filename": "models/controlnet_3d_ddpm-ct.pt",
131-
}]
138+
},
139+
]
132140
if not model_only:
133141
files += [
134142
{
@@ -148,7 +156,8 @@ def download_model_data(generate_version,root_dir, model_only=False):
148156
"path": "models/controlnet_3d_rflow-ct.pt",
149157
"repo_id": "nvidia/NV-Generate-CT",
150158
"filename": "models/controlnet_3d_rflow-ct.pt",
151-
}]
159+
},
160+
]
152161
if not model_only:
153162
files += [
154163
{
@@ -157,11 +166,11 @@ def download_model_data(generate_version,root_dir, model_only=False):
157166
"filename": "datasets/candidate_masks_flexible_size_and_spacing_4000.json",
158167
},
159168
]
160-
169+
161170
for file in files:
162171
file["path"] = file["path"] if "datasets/" not in file["path"] else os.path.join(root_dir, file["path"])
163172
if "repo_id" in file.keys():
164-
path = fetch_to_hf_path_cmd([file],root_dir=root_dir, revision="main")
173+
path = fetch_to_hf_path_cmd([file], root_dir=root_dir, revision="main")
165174
print("saved to:", path)
166175
else:
167176
download_url(url=file["url"], filepath=file["path"])
@@ -180,7 +189,9 @@ def download_model_data(generate_version,root_dir, model_only=False):
180189
type=str,
181190
default="./",
182191
)
183-
parser.add_argument("--model_only", dest="model_only", action="store_true", help="Download model only, not any dataset")
192+
parser.add_argument(
193+
"--model_only", dest="model_only", action="store_true", help="Download model only, not any dataset"
194+
)
184195

185196
args = parser.parse_args()
186197
download_model_data(args.version, args.root_dir, args.model_only)

generation/maisi/scripts/inference.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ def main():
8989
root_dir = tempfile.mkdtemp() if directory is None else directory
9090
print(root_dir)
9191

92-
download_model_data(maisi_version,root_dir)
92+
download_model_data(maisi_version, root_dir)
9393

9494
# ## Read in environment setting, including data directory, model directory, and output directory
9595
# The information for data directory, model directory, and output directory are saved in ./configs/environment.json

0 commit comments

Comments
 (0)