diff --git a/src/lightning/pytorch/CHANGELOG.md b/src/lightning/pytorch/CHANGELOG.md index ddc12a92e9f56..a3f147aafb39f 100644 --- a/src/lightning/pytorch/CHANGELOG.md +++ b/src/lightning/pytorch/CHANGELOG.md @@ -24,7 +24,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Changed -- Default to `RichProgressBar` and `RichModelSummary` if the rich package is available. Fallback to TQDMProgressBar and ModelSummary otherwise. ([#9580](https://github.com/Lightning-AI/pytorch-lightning/pull/9580)) +- Default to `RichProgressBar` and `RichModelSummary` if the rich package is available. Fallback to TQDMProgressBar and ModelSummary otherwise ([#20896](https://github.com/Lightning-AI/pytorch-lightning/pull/20896)) + + +- Fixed recursive symlink creation when `save_last='link'` and `save_top_k=-1` ([#21186](https://github.com/Lightning-AI/pytorch-lightning/pull/21186)) ### Removed diff --git a/src/lightning/pytorch/callbacks/model_checkpoint.py b/src/lightning/pytorch/callbacks/model_checkpoint.py index dfc0cebb8d07d..415e1dcac309b 100644 --- a/src/lightning/pytorch/callbacks/model_checkpoint.py +++ b/src/lightning/pytorch/callbacks/model_checkpoint.py @@ -484,7 +484,7 @@ def _save_checkpoint(self, trainer: "pl.Trainer", filepath: str) -> None: @staticmethod def _link_checkpoint(trainer: "pl.Trainer", filepath: str, linkpath: str) -> None: - if trainer.is_global_zero: + if trainer.is_global_zero and os.path.abspath(filepath) != os.path.abspath(linkpath): if os.path.islink(linkpath) or os.path.isfile(linkpath): os.remove(linkpath) elif os.path.isdir(linkpath): diff --git a/tests/tests_pytorch/callbacks/test_model_checkpoint_additional_cases.py b/tests/tests_pytorch/callbacks/test_model_checkpoint_additional_cases.py index fba9e865debfd..1b56453220edd 100644 --- a/tests/tests_pytorch/callbacks/test_model_checkpoint_additional_cases.py +++ b/tests/tests_pytorch/callbacks/test_model_checkpoint_additional_cases.py @@ -1,4 +1,5 @@ 
import math +import os from datetime import timedelta import pytest @@ -9,6 +10,7 @@ from lightning.pytorch import LightningModule, Trainer, seed_everything from lightning.pytorch.callbacks import ModelCheckpoint +from lightning.pytorch.demos.boring_classes import BoringModel class TinyDataset(Dataset): @@ -206,3 +208,24 @@ def test_model_checkpoint_defer_until_next_validation_when_val_every_2_epochs(tm expected = max(val_scores) # last/maximum value occurs at final validation epoch actual = float(ckpt.best_model_score) assert math.isclose(actual, expected, rel_tol=0, abs_tol=1e-6) + + +def test_model_checkpoint_save_last_link_symlink_bug(tmp_path): + """Regression test: ``save_last='link'`` with ``save_top_k=-1`` must not create a self-referential symlink.""" + trainer = Trainer( + default_root_dir=tmp_path, + max_epochs=2, + callbacks=[ModelCheckpoint(dirpath=tmp_path, every_n_epochs=10, save_last="link", save_top_k=-1)], + enable_checkpointing=True, + enable_model_summary=False, + logger=False, + ) + + model = BoringModel() + trainer.fit(model) + + last_ckpt = tmp_path / "last.ckpt" + assert last_ckpt.exists() + # readlink may return a relative target, so resolve it against the link's directory before comparing + if os.path.islink(str(last_ckpt)): + assert os.path.abspath(os.path.join(tmp_path, os.readlink(str(last_ckpt)))) != os.path.abspath(last_ckpt)