Skip to content

Commit

Permalink
allow errors during checkpoint loading (and …
Browse files: browse the repository at this point in the history
  • Loading branch information
svandenhaute committed Jul 1, 2024
1 parent f70dbe2 commit 290e0b4
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 6 deletions.
15 changes: 10 additions & 5 deletions psiflow/models/mace_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -752,11 +752,16 @@ def run(rank: int, args: argparse.Namespace, world_size: int) -> None:
}

for swa_eval in swas:
epoch = checkpoint_handler.load_latest(
state=tools.CheckpointState(model, optimizer, lr_scheduler),
swa=swa_eval,
device=device,
)
try:
epoch = checkpoint_handler.load_latest(
state=tools.CheckpointState(model, optimizer, lr_scheduler),
swa=swa_eval,
device=device,
)
except BaseException as e:
print('failed to load checkpoint for swa:{}'.format(swa_eval))
print(e)
continue
model.to(device)
if args.distributed:
distributed_model = DDP(model, device_ids=[local_rank])
Expand Down
2 changes: 1 addition & 1 deletion tests/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ def test_mace_train(gpu, mace_config, dataset, tmp_path):
# it with the manually computed value
training = dataset[:-5]
validation = dataset[-5:]
mace_config["start_swa"] = 1000
mace_config["start_swa"] = 100
model = MACE(**mace_config)
model.initialize(training)
hamiltonian0 = model.create_hamiltonian()
Expand Down

0 comments on commit 290e0b4

Please sign in to comment.