diff --git a/src/lightning/fabric/utilities/cloud_io.py b/src/lightning/fabric/utilities/cloud_io.py index 9d0a33afd0b77..637dfcd9b1671 100644 --- a/src/lightning/fabric/utilities/cloud_io.py +++ b/src/lightning/fabric/utilities/cloud_io.py @@ -13,6 +13,7 @@ # limitations under the License. """Utilities related to data saving/loading.""" +import errno import io import logging from pathlib import Path @@ -84,10 +85,16 @@ def _atomic_save(checkpoint: dict[str, Any], filepath: Union[str, Path]) -> None log.debug(f"Saving checkpoint: {filepath}") torch.save(checkpoint, bytesbuffer) - # We use a transaction here to avoid file corruption if the save gets interrupted - fs, urlpath = fsspec.core.url_to_fs(str(filepath)) - with fs.transaction, fs.open(urlpath, "wb") as f: - f.write(bytesbuffer.getvalue()) + try: + # We use a transaction here to avoid file corruption if the save gets interrupted + fs, urlpath = fsspec.core.url_to_fs(str(filepath)) + with fs.transaction, fs.open(urlpath, "wb") as f: + f.write(bytesbuffer.getvalue()) + except PermissionError as e: + if isinstance(e.__context__, OSError) and getattr(e.__context__, "errno", None) == errno.EXDEV: + raise RuntimeError( + 'Upgrade fsspec to enable cross-device local checkpoints: pip install "fsspec[http]>=2025.5.0"', + ) from e def _is_object_storage(fs: AbstractFileSystem) -> bool: diff --git a/src/lightning/pytorch/CHANGELOG.md b/src/lightning/pytorch/CHANGELOG.md index 0095367e9187a..18ef679312a66 100644 --- a/src/lightning/pytorch/CHANGELOG.md +++ b/src/lightning/pytorch/CHANGELOG.md @@ -11,12 +11,13 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Add enable_autolog_hparams argument to Trainer ([#20593](https://github.com/Lightning-AI/pytorch-lightning/pull/20593)) +- For cross-device local checkpoints, instruct users to install `fsspec>=2025.5.0` if unavailable ([#20780](https://github.com/Lightning-AI/pytorch-lightning/pull/20780)) + ### Changed - - ### Removed - diff --git a/src/lightning/pytorch/callbacks/model_checkpoint.py b/src/lightning/pytorch/callbacks/model_checkpoint.py index 85bfb65c0ea6e..6b7b2831a2e04 100644 --- a/src/lightning/pytorch/callbacks/model_checkpoint.py +++ b/src/lightning/pytorch/callbacks/model_checkpoint.py @@ -155,6 +155,10 @@ class ModelCheckpoint(Checkpoint): If the checkpoint's ``dirpath`` changed from what it was before while resuming the training, only ``best_model_path`` will be reloaded and a warning will be issued. + If you provide a ``filename`` on a mounted device where changing permissions is not allowed (causing ``chmod`` + to raise a ``PermissionError``), install `fsspec>=2025.5.0`. Then the error is caught, the file's permissions + remain unchanged, and the checkpoint is still saved. Otherwise, no checkpoint will be saved and training stops. + Raises: MisconfigurationException: If ``save_top_k`` is smaller than ``-1``, diff --git a/tests/tests_pytorch/trainer/connectors/test_checkpoint_connector.py b/tests/tests_pytorch/trainer/connectors/test_checkpoint_connector.py index 722742a3ccae0..662fd99d1b12c 100644 --- a/tests/tests_pytorch/trainer/connectors/test_checkpoint_connector.py +++ b/tests/tests_pytorch/trainer/connectors/test_checkpoint_connector.py @@ -11,10 +11,13 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import errno import os +import re from unittest import mock from unittest.mock import ANY, Mock +import fsspec import pytest import torch @@ -105,6 +108,31 @@ def test_hpc_max_ckpt_version(tmp_path): ) +def test_local_cross_device_checkpoint(tmpdir): + """Test that the _CheckpointConnector can write local cross-device files or raises an error if fsspec<2025.5.0.""" + model = BoringModel() + # hardcoding dir since `tmp_path` can be windows path + trainer = Trainer( + default_root_dir="memory://test_ckpt_for_fsspec", limit_train_batches=1, limit_val_batches=1, max_epochs=1 + ) + trainer.fit(model) + # Simulate the behavior of fsspec when writing to a local file system but other device. + with ( + mock.patch("os.rename", side_effect=OSError(errno.EXDEV, "Invalid cross-device link")), + mock.patch("os.chmod", side_effect=PermissionError("Operation not permitted")), + ): + if fsspec.__version__ < "2025.5.0": + with pytest.raises( + RuntimeError, + match=re.escape( + 'Upgrade fsspec to enable cross-device local checkpoints: pip install "fsspec[http]>=2025.5.0"' + ), + ): + trainer.save_checkpoint(tmpdir + "/test_ckpt_for_fsspec/hpc_ckpt.ckpt") + else: + trainer.save_checkpoint(tmpdir + "/test_ckpt_for_fsspec/hpc_ckpt.ckpt") + + def test_ckpt_for_fsspec(): """Test that the _CheckpointConnector is able to write to fsspec file systems.""" model = BoringModel()