diff --git a/tests/trainer/test_checkpoint.py b/tests/trainer/test_checkpoint.py index 3e93ce56b3..8ae247fabf 100644 --- a/tests/trainer/test_checkpoint.py +++ b/tests/trainer/test_checkpoint.py @@ -882,7 +882,8 @@ def _get_tmp_dir(self): if delete_local: # delete files locally, forcing trainer to look in object store - shutil.rmtree('first') + assert trainer_1._checkpoint_saver is not None + shutil.rmtree(trainer_1._checkpoint_saver.folder) trainer_2 = self.get_trainer( latest_filename=latest_filename,