Skip to content

Commit 34305db

Browse files
committed
Fix for issue #21110: Prevent recursive symlink creation in ModelCheckpoint
- Added a check in _link_checkpoint that compares the absolute paths of filepath and linkpath.
- Only create the symlink if the paths differ, avoiding self-linking when save_last='link' and save_top_k=-1.
- Updated the test to assert that the fix prevents the recursive symlink bug.
1 parent cd30ce4 commit 34305db

File tree

2 files changed

+59
-1
lines changed

2 files changed

+59
-1
lines changed

src/lightning/pytorch/callbacks/model_checkpoint.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -484,7 +484,7 @@ def _save_checkpoint(self, trainer: "pl.Trainer", filepath: str) -> None:
484484

485485
@staticmethod
486486
def _link_checkpoint(trainer: "pl.Trainer", filepath: str, linkpath: str) -> None:
487-
if trainer.is_global_zero:
487+
if trainer.is_global_zero and os.path.abspath(filepath) != os.path.abspath(linkpath):
488488
if os.path.islink(linkpath) or os.path.isfile(linkpath):
489489
os.remove(linkpath)
490490
elif os.path.isdir(linkpath):

tests/tests_pytorch/callbacks/test_model_checkpoint_additional_cases.py

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -206,3 +206,61 @@ def test_model_checkpoint_defer_until_next_validation_when_val_every_2_epochs(tm
206206
expected = max(val_scores) # last/maximum value occurs at final validation epoch
207207
actual = float(ckpt.best_model_score)
208208
assert math.isclose(actual, expected, rel_tol=0, abs_tol=1e-6)
209+
210+
211+
def test_model_checkpoint_save_last_link_symlink_bug(tmp_path):
    """Check that ``save_last='link'`` with ``save_top_k=-1`` does not leave ``last.ckpt`` linked to itself.

    Regression test for issue #21110. The checkpoint callback is configured with
    ``every_n_epochs=10`` while training runs for only two epochs, and the test then
    inspects ``last.ckpt`` in ``tmp_path``.

    Args:
        tmp_path: Temporary directory (pytest fixture) used both as the trainer's root
            directory and as the checkpoint directory.
    """
    import os

    class RandomDataset(Dataset):
        """Fixed-length dataset of random vectors."""

        def __init__(self, size, length):
            self.len = length
            self.data = torch.randn(length, size)

        def __getitem__(self, index):
            return self.data[index]

        def __len__(self):
            return self.len

    class BoringModel(LightningModule):
        """Minimal single-layer module that logs a summed output as its loss."""

        def __init__(self):
            super().__init__()
            self.layer = torch.nn.Linear(32, 2)

        def forward(self, x):
            return self.layer(x)

        def training_step(self, batch, batch_idx):
            loss = self(batch).sum()
            self.log("train_loss", loss)
            return {"loss": loss}

        def validation_step(self, batch, batch_idx):
            loss = self(batch).sum()
            self.log("valid_loss", loss)

        def test_step(self, batch, batch_idx):
            loss = self(batch).sum()
            self.log("test_loss", loss)

        def configure_optimizers(self):
            return torch.optim.SGD(self.layer.parameters(), lr=0.1)

    trainer = Trainer(
        default_root_dir=tmp_path,
        max_epochs=2,
        callbacks=[ModelCheckpoint(dirpath=tmp_path, every_n_epochs=10, save_last="link", save_top_k=-1)],
        enable_checkpointing=True,
        enable_model_summary=False,
        logger=False,
    )

    model = BoringModel()
    trainer.fit(model, train_dataloaders=DataLoader(RandomDataset(32, 64), batch_size=2))

    last_ckpt = tmp_path / "last.ckpt"
    # ``Path.exists`` follows symlinks, so a dangling or self-referential link fails here.
    assert last_ckpt.exists()
    if os.path.islink(last_ckpt):
        # ``os.readlink`` returns the target exactly as stored, which may be relative to
        # the link's directory. Resolve it before comparing, otherwise a self-link stored
        # as "last.ckpt" would never equal the absolute link path and the check would be
        # vacuous.
        target = os.readlink(last_ckpt)
        resolved_target = os.path.normpath(os.path.join(os.path.dirname(last_ckpt), target))
        assert resolved_target != os.path.abspath(last_ckpt)

0 commit comments

Comments
 (0)