@@ -212,40 +212,7 @@ def test_model_checkpoint_save_last_link_symlink_bug(tmp_path):
212
212
"""Reproduce the bug where save_last='link' and save_top_k=-1 creates a recursive symlink."""
213
213
import os
214
214
215
- class RandomDataset (Dataset ):
216
- def __init__ (self , size , length ):
217
- self .len = length
218
- self .data = torch .randn (length , size )
219
-
220
- def __getitem__ (self , index ):
221
- return self .data [index ]
222
-
223
- def __len__ (self ):
224
- return self .len
225
-
226
- class BoringModel (LightningModule ):
227
- def __init__ (self ):
228
- super ().__init__ ()
229
- self .layer = torch .nn .Linear (32 , 2 )
230
-
231
- def forward (self , x ):
232
- return self .layer (x )
233
-
234
- def training_step (self , batch , batch_idx ):
235
- loss = self (batch ).sum ()
236
- self .log ("train_loss" , loss )
237
- return {"loss" : loss }
238
-
239
- def validation_step (self , batch , batch_idx ):
240
- loss = self (batch ).sum ()
241
- self .log ("valid_loss" , loss )
242
-
243
- def test_step (self , batch , batch_idx ):
244
- loss = self (batch ).sum ()
245
- self .log ("test_loss" , loss )
246
-
247
- def configure_optimizers (self ):
248
- return torch .optim .SGD (self .layer .parameters (), lr = 0.1 )
215
+ from lightning .pytorch .demos .boring_classes import BoringModel
249
216
250
217
trainer = Trainer (
251
218
default_root_dir = tmp_path ,
@@ -257,10 +224,10 @@ def configure_optimizers(self):
257
224
)
258
225
259
226
model = BoringModel ()
260
- trainer .fit (model , train_dataloaders = DataLoader ( RandomDataset ( 32 , 64 ), batch_size = 2 ) )
227
+ trainer .fit (model )
261
228
262
229
last_ckpt = tmp_path / "last.ckpt"
263
230
assert last_ckpt .exists ()
264
- # With the fix, it should not be a symlink to itself
231
+ # With the fix, if a symlink exists, it should not point to itself (preventing recursion)
265
232
if os .path .islink (str (last_ckpt )):
266
233
assert os .readlink (str (last_ckpt )) != str (last_ckpt )
0 commit comments