66from huggingface_hub import snapshot_download
77from typing import List , Dict , Optional
88
9+
910def fetch_to_hf_path_cmd (
1011 items : List [Dict [str , str ]],
11- root_dir : str = "./" , # staging dir for CLI output
12+ root_dir : str = "./" , # staging dir for CLI output
1213 revision : str = "main" ,
1314 overwrite : bool = False ,
14- token : Optional [str ] = None , # or rely on env HUGGINGFACE_HUB_TOKEN
15+ token : Optional [str ] = None , # or rely on env HUGGINGFACE_HUB_TOKEN
1516) -> list [str ]:
1617 """
1718 items: list of {"repo_id": "...", "filename": "path/in/repo.ext", "path": "local/target.ext"}
@@ -25,11 +26,11 @@ def fetch_to_hf_path_cmd(
2526 env = os .environ .copy ()
2627 if token :
2728 env ["HUGGINGFACE_HUB_TOKEN" ] = token
28- env .setdefault ("HF_HUB_ENABLE_HF_TRANSFER" , "0" ) # safer in Jupyter
29- env .setdefault ("HF_HUB_DISABLE_PROGRESS_BARS" , "0" ) # show CLI progress in terminal
29+ env .setdefault ("HF_HUB_ENABLE_HF_TRANSFER" , "0" ) # safer in Jupyter
30+ env .setdefault ("HF_HUB_DISABLE_PROGRESS_BARS" , "0" ) # show CLI progress in terminal
3031
3132 for it in items :
32- repo_id = it ["repo_id" ]
33+ repo_id = it ["repo_id" ]
3334 repo_file = it ["filename" ]
3435 dst = Path (it ["path" ])
3536 dst .parent .mkdir (parents = True , exist_ok = True )
@@ -40,11 +41,15 @@ def fetch_to_hf_path_cmd(
4041
4142 # Build command (no shell=True; no quoting issues)
4243 cmd = [
43- "huggingface-cli" , "download" ,
44+ "huggingface-cli" ,
45+ "download" ,
4446 repo_id ,
45- "--include" , repo_file ,
46- "--revision" , revision ,
47- "--local-dir" , str (root ),
47+ "--include" ,
48+ repo_file ,
49+ "--revision" ,
50+ revision ,
51+ "--local-dir" ,
52+ str (root ),
4853 ]
4954 # Run
5055 subprocess .run (cmd , check = True , env = env )
@@ -69,15 +74,14 @@ def fetch_to_hf_path_cmd(
6974 return saved
7075
7176
72-
73- def download_model_data (generate_version ,root_dir , model_only = False ):
77+ def download_model_data (generate_version , root_dir , model_only = False ):
7478 # TODO: remove the `files` after the files are uploaded to the NGC
7579 if generate_version == "ddpm-ct" or generate_version == "rflow-ct" :
7680 files = [
7781 {
7882 "path" : "models/autoencoder_v1.pt" ,
7983 "repo_id" : "nvidia/NV-Generate-CT" ,
80- "filename" :"models/autoencoder_v1.pt" ,
84+ "filename" : "models/autoencoder_v1.pt" ,
8185 },
8286 {
8387 "path" : "models/mask_generation_autoencoder.pt" ,
@@ -88,7 +92,8 @@ def download_model_data(generate_version,root_dir, model_only=False):
8892 "path" : "models/mask_generation_diffusion_unet.pt" ,
8993 "repo_id" : "nvidia/NV-Generate-CT" ,
9094 "filename" : "models/mask_generation_diffusion_unet.pt" ,
91- }]
95+ },
96+ ]
9297 if not model_only :
9398 files += [
9499 {
@@ -113,10 +118,12 @@ def download_model_data(generate_version,root_dir, model_only=False):
113118 "path" : "models/diff_unet_3d_rflow-mr.pt" ,
114119 "repo_id" : "nvidia/NV-Generate-MR" ,
115120 "filename" : "models/diff_unet_3d_rflow-mr.pt" ,
116- }
121+ },
117122 ]
118123 else :
119- raise ValueError (f"generate_version has to be chosen from ['ddpm-ct', 'rflow-ct', 'rflow-mr'], yet got { generate_version } ." )
124+ raise ValueError (
125+ f"generate_version has to be chosen from ['ddpm-ct', 'rflow-ct', 'rflow-mr'], yet got { generate_version } ."
126+ )
120127 if generate_version == "ddpm-ct" :
121128 files += [
122129 {
@@ -128,7 +135,8 @@ def download_model_data(generate_version,root_dir, model_only=False):
128135 "path" : "models/controlnet_3d_ddpm-ct.pt" ,
129136 "repo_id" : "nvidia/NV-Generate-CT" ,
130137 "filename" : "models/controlnet_3d_ddpm-ct.pt" ,
131- }]
138+ },
139+ ]
132140 if not model_only :
133141 files += [
134142 {
@@ -148,7 +156,8 @@ def download_model_data(generate_version,root_dir, model_only=False):
148156 "path" : "models/controlnet_3d_rflow-ct.pt" ,
149157 "repo_id" : "nvidia/NV-Generate-CT" ,
150158 "filename" : "models/controlnet_3d_rflow-ct.pt" ,
151- }]
159+ },
160+ ]
152161 if not model_only :
153162 files += [
154163 {
@@ -157,11 +166,11 @@ def download_model_data(generate_version,root_dir, model_only=False):
157166 "filename" : "datasets/candidate_masks_flexible_size_and_spacing_4000.json" ,
158167 },
159168 ]
160-
169+
161170 for file in files :
162171 file ["path" ] = file ["path" ] if "datasets/" not in file ["path" ] else os .path .join (root_dir , file ["path" ])
163172 if "repo_id" in file .keys ():
164- path = fetch_to_hf_path_cmd ([file ],root_dir = root_dir , revision = "main" )
173+ path = fetch_to_hf_path_cmd ([file ], root_dir = root_dir , revision = "main" )
165174 print ("saved to:" , path )
166175 else :
167176 download_url (url = file ["url" ], filepath = file ["path" ])
@@ -180,7 +189,9 @@ def download_model_data(generate_version,root_dir, model_only=False):
180189 type = str ,
181190 default = "./" ,
182191 )
183- parser .add_argument ("--model_only" , dest = "model_only" , action = "store_true" , help = "Download model only, not any dataset" )
192+ parser .add_argument (
193+ "--model_only" , dest = "model_only" , action = "store_true" , help = "Download model only, not any dataset"
194+ )
184195
185196 args = parser .parse_args ()
186197 download_model_data (args .version , args .root_dir , args .model_only )
0 commit comments