diff --git a/checkpoint/CHANGELOG.md b/checkpoint/CHANGELOG.md index 82c53411..0d2f6936 100644 --- a/checkpoint/CHANGELOG.md +++ b/checkpoint/CHANGELOG.md @@ -10,6 +10,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added - Rolled forward change to improve TensorStore I/O efficiency. - Memory efficient broadcasting from one model replica to others. +- Ability to check if a checkpoint save is in progress. ### Changed - Allow one directory creation request per item rather than 1 per item per host. diff --git a/checkpoint/orbax/checkpoint/checkpoint_manager.py b/checkpoint/orbax/checkpoint/checkpoint_manager.py index d2e510f0..64bc00a5 100644 --- a/checkpoint/orbax/checkpoint/checkpoint_manager.py +++ b/checkpoint/orbax/checkpoint/checkpoint_manager.py @@ -1463,7 +1463,6 @@ def wait_until_finished(self): """See superclass documentation.""" t = self._finalize_thread if t is not None: - self._finalize_thread = None try: t.join() except BaseException as e: # pylint:disable=broad-exception-caught @@ -1472,6 +1471,8 @@ def wait_until_finished(self): assert self._checkpoints self._checkpoints = self._checkpoints[:-1] raise e + finally: + self._finalize_thread = None # Additional work is being done on process 0 of the finalize threads. # When joining the threads, we must wait for all threads to complete # before proceeding. @@ -1485,6 +1486,10 @@ def wait_until_finished(self): processes=self._multiprocessing_options.active_processes, ) + def is_saving_in_progress(self) -> bool: + """Returns whether a checkpoint save is in progress.""" + return self._finalize_thread is not None + def check_for_errors(self): """See superclass documentation.""" if is_async_checkpointer(self._checkpointer): @@ -1528,7 +1533,7 @@ def _finalize_checkpoint(self, step: int): ) def _finalize(self, directory: epath.Path, steps_to_remove: List[int]): - """Cleans up old checkpoints and synchronizes hosts.""" + """Finalizes individual items and starts garbage collection.""" self._non_blocking_checkpoint_metadata_store.wait_until_finished() self._wait_for_checkpointers() # If an error is encountered while waiting for commit futures to complete,