@@ -173,15 +173,25 @@ def download_infercence_cache(configs: Any, model_version: str = "v0.2.0") -> No
173173 urllib .request .urlretrieve (tos_url , cache_path )
174174
175175 checkpoint_path = configs .load_checkpoint_path
176- checkpoint_path = os .path .join (
177- code_directory , f"release_data/checkpoint/model_{ model_version } .pt"
178- )
179176
180177 if not opexists (checkpoint_path ):
178+ checkpoint_path = os .path .join (
179+ code_directory , f"release_data/checkpoint/model_{ model_version } .pt"
180+ )
181181 os .makedirs (os .path .dirname (checkpoint_path ), exist_ok = True )
182182 tos_url = URL [f"model_{ model_version } " ]
183183 logger .info (f"Downloading model checkpoint from\n { tos_url } ..." )
184184 urllib .request .urlretrieve (tos_url , checkpoint_path )
185+ try :
186+ ckpt = torch .load (checkpoint_path )
187+ del ckpt
188+ except :
189+ os .remove (checkpoint_path )
190+ raise RuntimeError (
191+ "Download model checkpoint failed, please download by yourself with "
192+ f"wget { tos_url } -O { checkpoint_path } "
193+ )
194+ configs .load_checkpoint_path = checkpoint_path
185195
186196
187197def update_inference_configs (configs : Any , N_token : int ):
@@ -299,13 +309,5 @@ def run() -> None:
299309 main (configs )
300310
301311
302- def run_default () -> None :
303- inference_configs ["load_checkpoint_path" ] = "/af3-dev/release_model/model_v0.2.0.pt"
304- configs_base ["model" ]["N_cycle" ] = 10
305- configs_base ["sample_diffusion" ]["N_sample" ] = 5
306- configs_base ["sample_diffusion" ]["N_step" ] = 200
307- run ()
308-
309-
310312if __name__ == "__main__" :
311313 run ()
0 commit comments