Skip to content

Commit 2d5278d — Merge branch 'maisi' of https://github.com/Can-Zhao/tutorials into maisi
(2 parents: 7883209 + 9b0a18f)

File tree

4 files changed: +34 additions, −23 deletions

generation/maisi/README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
-# 🚨🚨🚨 THIS FOLDER IS DEPRECATED 🚨🚨🚨
+# 🚨🚨🚨 THIS FOLDER IS DEPRECATED 🚨🚨🚨
 # 👉 Please switch to: [https://github.com/NVIDIA-Medtech/NV-Generate-CTMR/tree/main](https://github.com/NVIDIA-Medtech/NV-Generate-CTMR/tree/main)
 
 # Medical AI for Synthetic Imaging (MAISI)

generation/maisi/maisi_inference_tutorial.ipynb

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -205,7 +205,7 @@
 " os.makedirs(directory, exist_ok=True)\n",
 "root_dir = tempfile.mkdtemp() if directory is None else directory\n",
 "\n",
-"download_model_data(maisi_version,root_dir)\n",
+"download_model_data(maisi_version, root_dir)\n",
 "\n",
 "for file in files:\n",
 " file[\"path\"] = file[\"path\"] if \"datasets/\" not in file[\"path\"] else os.path.join(root_dir, file[\"path\"])\n",

generation/maisi/scripts/download_model_data.py

Lines changed: 31 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -17,12 +17,13 @@
 from huggingface_hub import snapshot_download
 from typing import List, Dict, Optional
 
+
 def fetch_to_hf_path_cmd(
     items: List[Dict[str, str]],
-    root_dir: str = "./", # staging dir for CLI output
+    root_dir: str = "./",  # staging dir for CLI output
     revision: str = "main",
     overwrite: bool = False,
-    token: Optional[str] = None, # or rely on env HUGGINGFACE_HUB_TOKEN
+    token: Optional[str] = None,  # or rely on env HUGGINGFACE_HUB_TOKEN
 ) -> list[str]:
     """
     items: list of {"repo_id": "...", "filename": "path/in/repo.ext", "path": "local/target.ext"}
@@ -36,11 +37,11 @@ def fetch_to_hf_path_cmd(
     env = os.environ.copy()
     if token:
         env["HUGGINGFACE_HUB_TOKEN"] = token
-    env.setdefault("HF_HUB_ENABLE_HF_TRANSFER", "0") # safer in Jupyter
-    env.setdefault("HF_HUB_DISABLE_PROGRESS_BARS", "0") # show CLI progress in terminal
+    env.setdefault("HF_HUB_ENABLE_HF_TRANSFER", "0")  # safer in Jupyter
+    env.setdefault("HF_HUB_DISABLE_PROGRESS_BARS", "0")  # show CLI progress in terminal
 
     for it in items:
-        repo_id = it["repo_id"]
+        repo_id = it["repo_id"]
         repo_file = it["filename"]
         dst = Path(it["path"])
         dst.parent.mkdir(parents=True, exist_ok=True)
@@ -51,11 +52,15 @@ def fetch_to_hf_path_cmd(
 
         # Build command (no shell=True; no quoting issues)
         cmd = [
-            "huggingface-cli", "download",
+            "huggingface-cli",
+            "download",
             repo_id,
-            "--include", repo_file,
-            "--revision", revision,
-            "--local-dir", str(root),
+            "--include",
+            repo_file,
+            "--revision",
+            revision,
+            "--local-dir",
+            str(root),
         ]
         # Run
         subprocess.run(cmd, check=True, env=env)
@@ -80,15 +85,14 @@ def fetch_to_hf_path_cmd(
     return saved
 
 
-
-def download_model_data(generate_version,root_dir, model_only=False):
+def download_model_data(generate_version, root_dir, model_only=False):
     # TODO: remove the `files` after the files are uploaded to the NGC
     if generate_version == "ddpm-ct" or generate_version == "rflow-ct":
         files = [
             {
                 "path": "models/autoencoder_v1.pt",
                 "repo_id": "nvidia/NV-Generate-CT",
-                "filename":"models/autoencoder_v1.pt",
+                "filename": "models/autoencoder_v1.pt",
             },
             {
                 "path": "models/mask_generation_autoencoder.pt",
@@ -99,7 +103,8 @@ def download_model_data(generate_version, root_dir, model_only=False):
                 "path": "models/mask_generation_diffusion_unet.pt",
                 "repo_id": "nvidia/NV-Generate-CT",
                 "filename": "models/mask_generation_diffusion_unet.pt",
-            }]
+            },
+        ]
     if not model_only:
         files += [
             {
@@ -124,10 +129,12 @@ def download_model_data(generate_version, root_dir, model_only=False):
                 "path": "models/diff_unet_3d_rflow-mr.pt",
                 "repo_id": "nvidia/NV-Generate-MR",
                 "filename": "models/diff_unet_3d_rflow-mr.pt",
-            }
+            },
         ]
     else:
-        raise ValueError(f"generate_version has to be chosen from ['ddpm-ct', 'rflow-ct', 'rflow-mr'], yet got {generate_version}.")
+        raise ValueError(
+            f"generate_version has to be chosen from ['ddpm-ct', 'rflow-ct', 'rflow-mr'], yet got {generate_version}."
+        )
     if generate_version == "ddpm-ct":
         files += [
             {
@@ -139,7 +146,8 @@ def download_model_data(generate_version, root_dir, model_only=False):
                 "path": "models/controlnet_3d_ddpm-ct.pt",
                 "repo_id": "nvidia/NV-Generate-CT",
                 "filename": "models/controlnet_3d_ddpm-ct.pt",
-            }]
+            },
+        ]
     if not model_only:
         files += [
             {
@@ -159,7 +167,8 @@ def download_model_data(generate_version, root_dir, model_only=False):
                 "path": "models/controlnet_3d_rflow-ct.pt",
                 "repo_id": "nvidia/NV-Generate-CT",
                 "filename": "models/controlnet_3d_rflow-ct.pt",
-            }]
+            },
+        ]
     if not model_only:
         files += [
             {
@@ -168,11 +177,11 @@ def download_model_data(generate_version, root_dir, model_only=False):
                 "filename": "datasets/candidate_masks_flexible_size_and_spacing_4000.json",
             },
         ]
-
+
     for file in files:
         file["path"] = file["path"] if "datasets/" not in file["path"] else os.path.join(root_dir, file["path"])
         if "repo_id" in file.keys():
-            path = fetch_to_hf_path_cmd([file],root_dir=root_dir, revision="main")
+            path = fetch_to_hf_path_cmd([file], root_dir=root_dir, revision="main")
             print("saved to:", path)
         else:
             download_url(url=file["url"], filepath=file["path"])
@@ -191,7 +200,9 @@ def download_model_data(generate_version, root_dir, model_only=False):
         type=str,
         default="./",
     )
-    parser.add_argument("--model_only", dest="model_only", action="store_true", help="Download model only, not any dataset")
+    parser.add_argument(
+        "--model_only", dest="model_only", action="store_true", help="Download model only, not any dataset"
+    )
 
     args = parser.parse_args()
     download_model_data(args.version, args.root_dir, args.model_only)

generation/maisi/scripts/inference.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ def main():
     root_dir = tempfile.mkdtemp() if directory is None else directory
     print(root_dir)
 
-    download_model_data(maisi_version,root_dir)
+    download_model_data(maisi_version, root_dir)
 
     # ## Read in environment setting, including data directory, model directory, and output directory
     # The information for data directory, model directory, and output directory are saved in ./configs/environment.json

0 commit comments

Comments
 (0)