@@ -2128,32 +2128,32 @@ def test_save_last_without_save_on_train_epoch_and_without_val(tmp_path):
2128
2128
2129
2129
def test_save_last_only_when_checkpoint_saved (tmp_path ):
2130
2130
"""Test that save_last only creates last.ckpt when another checkpoint is actually saved."""
2131
-
2131
+
2132
2132
class SelectiveModel (BoringModel ):
2133
2133
def __init__ (self ):
2134
2134
super ().__init__ ()
2135
2135
self .validation_step_outputs = []
2136
-
2136
+
2137
2137
def validation_step (self , batch , batch_idx ):
2138
2138
outputs = super ().validation_step (batch , batch_idx )
2139
2139
epoch = self .trainer .current_epoch
2140
2140
loss = torch .tensor (1.0 - epoch * 0.1 ) if epoch % 2 == 0 else torch .tensor (1.0 + epoch * 0.1 )
2141
2141
outputs ["val_loss" ] = loss
2142
2142
self .validation_step_outputs .append (outputs )
2143
2143
return outputs
2144
-
2144
+
2145
2145
def on_validation_epoch_end (self ):
2146
2146
if self .validation_step_outputs :
2147
2147
avg_loss = torch .stack ([x ["val_loss" ] for x in self .validation_step_outputs ]).mean ()
2148
2148
self .log ("val_loss" , avg_loss )
2149
2149
self .validation_step_outputs .clear ()
2150
2150
2151
2151
model = SelectiveModel ()
2152
-
2152
+
2153
2153
checkpoint_callback = ModelCheckpoint (
2154
2154
dirpath = tmp_path ,
2155
2155
filename = "best-{epoch}-{val_loss:.2f}" ,
2156
- monitor = "val_loss" ,
2156
+ monitor = "val_loss" ,
2157
2157
save_last = True ,
2158
2158
save_top_k = 1 ,
2159
2159
mode = "min" ,
@@ -2177,4 +2177,6 @@ def on_validation_epoch_end(self):
2177
2177
checkpoint_names = [f .name for f in checkpoint_files ]
2178
2178
assert "last.ckpt" in checkpoint_names , "last.ckpt should exist since checkpoints were saved"
2179
2179
expected_files = 2 # best checkpoint + last.ckpt
2180
- assert len (checkpoint_files ) == expected_files , f"Expected { expected_files } files, got { len (checkpoint_files )} : { checkpoint_names } "
2180
+ assert len (checkpoint_files ) == expected_files , (
2181
+ f"Expected { expected_files } files, got { len (checkpoint_files )} : { checkpoint_names } "
2182
+ )
0 commit comments