Skip to content

Commit ca0170c

Browse files
Merge branch 'fix_readme_shibo' into 'main'
fix ckpt cache bug See merge request molecule/protenix!45
2 parents e0421d4 + 7038523 commit ca0170c

3 files changed

Lines changed: 14 additions & 13 deletions

File tree

runner/batch_inference.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -162,7 +162,6 @@ def generate_infer_jsons(
162162

163163

164164
def get_default_runner(seeds: Optional[list] = None) -> InferenceRunner:
165-
inference_configs["load_checkpoint_path"] = "/af3-dev/release_model/model_v0.2.0.pt"
166165
configs_base["use_deepspeed_evo_attention"] = (
167166
os.environ.get("USE_DEEPSPEED_EVO_ATTTENTION", False) == "true"
168167
)

runner/inference.py

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -173,15 +173,25 @@ def download_infercence_cache(configs: Any, model_version: str = "v0.2.0") -> No
173173
urllib.request.urlretrieve(tos_url, cache_path)
174174

175175
checkpoint_path = configs.load_checkpoint_path
176-
checkpoint_path = os.path.join(
177-
code_directory, f"release_data/checkpoint/model_{model_version}.pt"
178-
)
179176

180177
if not opexists(checkpoint_path):
178+
checkpoint_path = os.path.join(
179+
code_directory, f"release_data/checkpoint/model_{model_version}.pt"
180+
)
181181
os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)
182182
tos_url = URL[f"model_{model_version}"]
183183
logger.info(f"Downloading model checkpoint from\n {tos_url}...")
184184
urllib.request.urlretrieve(tos_url, checkpoint_path)
185+
try:
186+
ckpt = torch.load(checkpoint_path)
187+
del ckpt
188+
except:
189+
os.remove(checkpoint_path)
190+
raise RuntimeError(
191+
"Download model checkpoint failed, please download by yourself with "
192+
f"wget {tos_url} -O {checkpoint_path}"
193+
)
194+
configs.load_checkpoint_path = checkpoint_path
185195

186196

187197
def update_inference_configs(configs: Any, N_token: int):
@@ -299,13 +309,5 @@ def run() -> None:
299309
main(configs)
300310

301311

302-
def run_default() -> None:
303-
inference_configs["load_checkpoint_path"] = "/af3-dev/release_model/model_v0.2.0.pt"
304-
configs_base["model"]["N_cycle"] = 10
305-
configs_base["sample_diffusion"]["N_sample"] = 5
306-
configs_base["sample_diffusion"]["N_step"] = 200
307-
run()
308-
309-
310312
if __name__ == "__main__":
311313
run()

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
setup(
2121
name="protenix",
2222
python_requires=">=3.10",
23-
version="0.3.1",
23+
version="0.3.2",
2424
description="A trainable PyTorch reproduction of AlphaFold 3.",
2525
author="Bytedance Inc.",
2626
url="https://github.com/bytedance/Protenix",

0 commit comments

Comments
 (0)