 from huggingface_hub import snapshot_download
 from typing import List, Dict, Optional
 
+
 def fetch_to_hf_path_cmd(
     items: List[Dict[str, str]],
-    root_dir: str = "./", # staging dir for CLI output
+    root_dir: str = "./",  # staging dir for CLI output
     revision: str = "main",
     overwrite: bool = False,
-    token: Optional[str] = None, # or rely on env HUGGINGFACE_HUB_TOKEN
+    token: Optional[str] = None,  # or rely on env HUGGINGFACE_HUB_TOKEN
 ) -> list[str]:
     """
     items: list of {"repo_id": "...", "filename": "path/in/repo.ext", "path": "local/target.ext"}
@@ -36,11 +37,11 @@ def fetch_to_hf_path_cmd(
     env = os.environ.copy()
     if token:
         env["HUGGINGFACE_HUB_TOKEN"] = token
-    env.setdefault("HF_HUB_ENABLE_HF_TRANSFER", "0") # safer in Jupyter
-    env.setdefault("HF_HUB_DISABLE_PROGRESS_BARS", "0") # show CLI progress in terminal
+    env.setdefault("HF_HUB_ENABLE_HF_TRANSFER", "0")  # safer in Jupyter
+    env.setdefault("HF_HUB_DISABLE_PROGRESS_BARS", "0")  # show CLI progress in terminal
 
     for it in items:
-        repo_id   = it["repo_id"]
+        repo_id = it["repo_id"]
         repo_file = it["filename"]
         dst = Path(it["path"])
         dst.parent.mkdir(parents=True, exist_ok=True)
@@ -51,11 +52,15 @@ def fetch_to_hf_path_cmd(
 
         # Build command (no shell=True; no quoting issues)
         cmd = [
-            "huggingface-cli", "download",
+            "huggingface-cli",
+            "download",
             repo_id,
-            "--include", repo_file,
-            "--revision", revision,
-            "--local-dir", str(root),
+            "--include",
+            repo_file,
+            "--revision",
+            revision,
+            "--local-dir",
+            str(root),
         ]
         # Run
         subprocess.run(cmd, check=True, env=env)
@@ -80,15 +85,14 @@ def fetch_to_hf_path_cmd(
     return saved
 
 
-
-def download_model_data(generate_version,root_dir, model_only=False):
+def download_model_data(generate_version, root_dir, model_only=False):
     # TODO: remove the `files` after the files are uploaded to the NGC
     if generate_version == "ddpm-ct" or generate_version == "rflow-ct":
         files = [
             {
                 "path": "models/autoencoder_v1.pt",
                 "repo_id": "nvidia/NV-Generate-CT",
-                "filename":"models/autoencoder_v1.pt",
+                "filename": "models/autoencoder_v1.pt",
             },
             {
                 "path": "models/mask_generation_autoencoder.pt",
@@ -99,7 +103,8 @@ def download_model_data(generate_version,root_dir, model_only=False):
                 "path": "models/mask_generation_diffusion_unet.pt",
                 "repo_id": "nvidia/NV-Generate-CT",
                 "filename": "models/mask_generation_diffusion_unet.pt",
-            }]
+            },
+        ]
         if not model_only:
             files += [
                 {
@@ -124,10 +129,12 @@ def download_model_data(generate_version,root_dir, model_only=False):
                 "path": "models/diff_unet_3d_rflow-mr.pt",
                 "repo_id": "nvidia/NV-Generate-MR",
                 "filename": "models/diff_unet_3d_rflow-mr.pt",
-            }
+            },
         ]
     else:
-        raise ValueError(f"generate_version has to be chosen from ['ddpm-ct', 'rflow-ct', 'rflow-mr'], yet got {generate_version}.")
+        raise ValueError(
+            f"generate_version has to be chosen from ['ddpm-ct', 'rflow-ct', 'rflow-mr'], yet got {generate_version}."
+        )
     if generate_version == "ddpm-ct":
         files += [
             {
@@ -139,7 +146,8 @@ def download_model_data(generate_version,root_dir, model_only=False):
                 "path": "models/controlnet_3d_ddpm-ct.pt",
                 "repo_id": "nvidia/NV-Generate-CT",
                 "filename": "models/controlnet_3d_ddpm-ct.pt",
-            }]
+            },
+        ]
         if not model_only:
             files += [
                 {
@@ -159,7 +167,8 @@ def download_model_data(generate_version,root_dir, model_only=False):
                 "path": "models/controlnet_3d_rflow-ct.pt",
                 "repo_id": "nvidia/NV-Generate-CT",
                 "filename": "models/controlnet_3d_rflow-ct.pt",
-            }]
+            },
+        ]
         if not model_only:
             files += [
                 {
@@ -168,11 +177,11 @@ def download_model_data(generate_version,root_dir, model_only=False):
                     "filename": "datasets/candidate_masks_flexible_size_and_spacing_4000.json",
                 },
             ]
-    
+
     for file in files:
         file["path"] = file["path"] if "datasets/" not in file["path"] else os.path.join(root_dir, file["path"])
         if "repo_id" in file.keys():
-            path = fetch_to_hf_path_cmd([file],root_dir=root_dir, revision="main")
+            path = fetch_to_hf_path_cmd([file], root_dir=root_dir, revision="main")
             print("saved to:", path)
         else:
             download_url(url=file["url"], filepath=file["path"])
@@ -191,7 +200,9 @@ def download_model_data(generate_version,root_dir, model_only=False):
         type=str,
         default="./",
     )
-    parser.add_argument("--model_only", dest="model_only", action="store_true", help="Download model only, not any dataset")
+    parser.add_argument(
+        "--model_only", dest="model_only", action="store_true", help="Download model only, not any dataset"
+    )
 
     args = parser.parse_args()
     download_model_data(args.version, args.root_dir, args.model_only)
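
For reference, a minimal usage sketch of the fetch_to_hf_path_cmd helper reformatted above, assuming huggingface-cli is installed and on PATH; the entry mirrors one item from the files list in download_model_data and is illustrative only, not part of the commit:

    # Illustrative only: stage one NV-Generate-CT checkpoint via the CLI wrapper above.
    files = [
        {
            "path": "models/autoencoder_v1.pt",
            "repo_id": "nvidia/NV-Generate-CT",
            "filename": "models/autoencoder_v1.pt",
        },
    ]
    saved = fetch_to_hf_path_cmd(files, root_dir="./", revision="main")
    print("saved to:", saved)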