Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/lightning/pytorch/callbacks/model_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -484,7 +484,7 @@ def _save_checkpoint(self, trainer: "pl.Trainer", filepath: str) -> None:

@staticmethod
def _link_checkpoint(trainer: "pl.Trainer", filepath: str, linkpath: str) -> None:
if trainer.is_global_zero:
if trainer.is_global_zero and os.path.abspath(filepath) != os.path.abspath(linkpath):
if os.path.islink(linkpath) or os.path.isfile(linkpath):
os.remove(linkpath)
elif os.path.isdir(linkpath):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -206,3 +206,28 @@ def test_model_checkpoint_defer_until_next_validation_when_val_every_2_epochs(tm
expected = max(val_scores) # last/maximum value occurs at final validation epoch
actual = float(ckpt.best_model_score)
assert math.isclose(actual, expected, rel_tol=0, abs_tol=1e-6)


def test_model_checkpoint_save_last_link_symlink_bug(tmp_path):
"""Reproduce the bug where save_last='link' and save_top_k=-1 creates a recursive symlink."""
import os

from lightning.pytorch.demos.boring_classes import BoringModel

trainer = Trainer(
default_root_dir=tmp_path,
max_epochs=2,
callbacks=[ModelCheckpoint(dirpath=tmp_path, every_n_epochs=10, save_last="link", save_top_k=-1)],
enable_checkpointing=True,
enable_model_summary=False,
logger=False,
)

model = BoringModel()
trainer.fit(model)

last_ckpt = tmp_path / "last.ckpt"
assert last_ckpt.exists()
# With the fix, if a symlink exists, it should not point to itself (preventing recursion)
if os.path.islink(str(last_ckpt)):
assert os.readlink(str(last_ckpt)) != str(last_ckpt)
Loading