Skip to content

Commit 2f459d3

Browse files
committed
Address PR comments: refactor test to use boring classes and clarify assertion
1 parent 34305db commit 2f459d3

File tree

1 file changed

+3
-36
lines changed

1 file changed

+3
-36
lines changed

tests/tests_pytorch/callbacks/test_model_checkpoint_additional_cases.py

Lines changed: 3 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -212,40 +212,7 @@ def test_model_checkpoint_save_last_link_symlink_bug(tmp_path):
212212
"""Reproduce the bug where save_last='link' and save_top_k=-1 creates a recursive symlink."""
213213
import os
214214

215-
class RandomDataset(Dataset):
216-
def __init__(self, size, length):
217-
self.len = length
218-
self.data = torch.randn(length, size)
219-
220-
def __getitem__(self, index):
221-
return self.data[index]
222-
223-
def __len__(self):
224-
return self.len
225-
226-
class BoringModel(LightningModule):
227-
def __init__(self):
228-
super().__init__()
229-
self.layer = torch.nn.Linear(32, 2)
230-
231-
def forward(self, x):
232-
return self.layer(x)
233-
234-
def training_step(self, batch, batch_idx):
235-
loss = self(batch).sum()
236-
self.log("train_loss", loss)
237-
return {"loss": loss}
238-
239-
def validation_step(self, batch, batch_idx):
240-
loss = self(batch).sum()
241-
self.log("valid_loss", loss)
242-
243-
def test_step(self, batch, batch_idx):
244-
loss = self(batch).sum()
245-
self.log("test_loss", loss)
246-
247-
def configure_optimizers(self):
248-
return torch.optim.SGD(self.layer.parameters(), lr=0.1)
215+
from lightning.pytorch.demos.boring_classes import BoringModel
249216

250217
trainer = Trainer(
251218
default_root_dir=tmp_path,
@@ -257,10 +224,10 @@ def configure_optimizers(self):
257224
)
258225

259226
model = BoringModel()
260-
trainer.fit(model, train_dataloaders=DataLoader(RandomDataset(32, 64), batch_size=2))
227+
trainer.fit(model)
261228

262229
last_ckpt = tmp_path / "last.ckpt"
263230
assert last_ckpt.exists()
264-
# With the fix, it should not be a symlink to itself
231+
# With the fix, if a symlink exists, it should not point to itself (preventing recursion)
265232
if os.path.islink(str(last_ckpt)):
266233
assert os.readlink(str(last_ckpt)) != str(last_ckpt)

0 commit comments

Comments
 (0)