Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/lightning/pytorch/callbacks/model_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -484,7 +484,7 @@ def _save_checkpoint(self, trainer: "pl.Trainer", filepath: str) -> None:

@staticmethod
def _link_checkpoint(trainer: "pl.Trainer", filepath: str, linkpath: str) -> None:
if trainer.is_global_zero:
if trainer.is_global_zero and os.path.abspath(filepath) != os.path.abspath(linkpath):
if os.path.islink(linkpath) or os.path.isfile(linkpath):
os.remove(linkpath)
elif os.path.isdir(linkpath):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -206,3 +206,61 @@ def test_model_checkpoint_defer_until_next_validation_when_val_every_2_epochs(tm
expected = max(val_scores) # last/maximum value occurs at final validation epoch
actual = float(ckpt.best_model_score)
assert math.isclose(actual, expected, rel_tol=0, abs_tol=1e-6)


def test_model_checkpoint_save_last_link_symlink_bug(tmp_path):
    """Reproduce the bug where save_last='link' and save_top_k=-1 creates a recursive symlink.

    Trains a tiny model for two epochs with a ``ModelCheckpoint`` configured with
    ``every_n_epochs=10``, ``save_last="link"`` and ``save_top_k=-1``, then checks
    that ``last.ckpt`` exists in ``tmp_path`` and, if it is a symlink, that it does
    not resolve back to itself.

    Args:
        tmp_path: pytest-provided temporary directory used both as the trainer's
            root directory and as the checkpoint directory.
    """
    import os

    class RandomDataset(Dataset):
        """Fixed-length dataset of random vectors."""

        def __init__(self, size, length):
            self.len = length
            self.data = torch.randn(length, size)

        def __getitem__(self, index):
            return self.data[index]

        def __len__(self):
            return self.len

    class BoringModel(LightningModule):
        """Minimal single-linear-layer module that logs one loss per stage."""

        def __init__(self):
            super().__init__()
            self.layer = torch.nn.Linear(32, 2)

        def forward(self, x):
            return self.layer(x)

        def training_step(self, batch, batch_idx):
            loss = self(batch).sum()
            self.log("train_loss", loss)
            return {"loss": loss}

        def validation_step(self, batch, batch_idx):
            loss = self(batch).sum()
            self.log("valid_loss", loss)

        def test_step(self, batch, batch_idx):
            loss = self(batch).sum()
            self.log("test_loss", loss)

        def configure_optimizers(self):
            return torch.optim.SGD(self.layer.parameters(), lr=0.1)

    trainer = Trainer(
        default_root_dir=tmp_path,
        max_epochs=2,
        callbacks=[ModelCheckpoint(dirpath=tmp_path, every_n_epochs=10, save_last="link", save_top_k=-1)],
        enable_checkpointing=True,
        enable_model_summary=False,
        logger=False,
    )

    model = BoringModel()
    trainer.fit(model, train_dataloaders=DataLoader(RandomDataset(32, 64), batch_size=2))

    last_ckpt = tmp_path / "last.ckpt"
    # Path.exists() follows symlinks, so a dangling or self-referential link fails here.
    assert last_ckpt.exists()
    # With the fix, it should not be a symlink to itself
    link_path = str(last_ckpt)
    if os.path.islink(link_path):
        # os.readlink returns the stored target verbatim, which may be relative to
        # the link's directory; resolve it before comparing so that a relative
        # self-reference (e.g. "last.ckpt") is detected as well.
        target = os.readlink(link_path)
        resolved_target = os.path.abspath(os.path.join(os.path.dirname(link_path), target))
        assert resolved_target != os.path.abspath(link_path)
Loading