diff --git a/src/fidder/model/_tests/test_download_checkpoint.py b/src/fidder/model/_tests/test_download_checkpoint.py index 1f8c593..8fe6497 100644 --- a/src/fidder/model/_tests/test_download_checkpoint.py +++ b/src/fidder/model/_tests/test_download_checkpoint.py @@ -4,5 +4,5 @@ def test_download_and_load_latest_checkpoint(): checkpoint_file = get_latest_checkpoint() model = Fidder() - model.load_from_checkpoint(checkpoint_file) + model.load_from_checkpoint(checkpoint_file, map_location="cpu") assert isinstance(model, Fidder) diff --git a/src/fidder/predict/predict.py b/src/fidder/predict/predict.py index 3869616..4ef3dc2 100644 --- a/src/fidder/predict/predict.py +++ b/src/fidder/predict/predict.py @@ -47,7 +47,7 @@ def predict_fiducial_mask( # prepare model if model_checkpoint_file is None: model_checkpoint_file = get_latest_checkpoint() - model = Fidder.load_from_checkpoint(model_checkpoint_file) + model = Fidder.load_from_checkpoint(model_checkpoint_file, map_location="cpu") model.eval() # predict