diff --git a/psiflow/models/mace_utils.py b/psiflow/models/mace_utils.py index 315e85f..31a71e9 100644 --- a/psiflow/models/mace_utils.py +++ b/psiflow/models/mace_utils.py @@ -752,11 +752,16 @@ def run(rank: int, args: argparse.Namespace, world_size: int) -> None: } for swa_eval in swas: - epoch = checkpoint_handler.load_latest( - state=tools.CheckpointState(model, optimizer, lr_scheduler), - swa=swa_eval, - device=device, - ) + try: + epoch = checkpoint_handler.load_latest( + state=tools.CheckpointState(model, optimizer, lr_scheduler), + swa=swa_eval, + device=device, + ) + except BaseException as e: + print('failed to load checkpoint for swa:{}'.format(swa_eval)) + print(e) + continue model.to(device) if args.distributed: distributed_model = DDP(model, device_ids=[local_rank]) diff --git a/tests/test_models.py b/tests/test_models.py index 6935aac..40c70cf 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -83,7 +83,7 @@ def test_mace_train(gpu, mace_config, dataset, tmp_path): # it with the manually computed value training = dataset[:-5] validation = dataset[-5:] - mace_config["start_swa"] = 1000 + mace_config["start_swa"] = 100 model = MACE(**mace_config) model.initialize(training) hamiltonian0 = model.create_hamiltonian()