Skip to content

Commit

Permalink
CheckpointManager: add is_saving_in_progress method.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 649440310
  • Loading branch information
Orbax Authors committed Jul 11, 2024
1 parent 1e06498 commit 384a6c9
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 2 deletions.
1 change: 1 addition & 0 deletions checkpoint/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Added
- Rolled forward change to improve TensorStore I/O efficiency.
- Memory efficient broadcasting from one model replica to others.
- Ability to check if a checkpoint save is in progress.

### Changed
- Allow one directory creation request per item rather than 1 per item per host.
Expand Down
9 changes: 7 additions & 2 deletions checkpoint/orbax/checkpoint/checkpoint_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -1463,7 +1463,6 @@ def wait_until_finished(self):
"""See superclass documentation."""
t = self._finalize_thread
if t is not None:
self._finalize_thread = None
try:
t.join()
except BaseException as e: # pylint:disable=broad-exception-caught
Expand All @@ -1472,6 +1471,8 @@ def wait_until_finished(self):
assert self._checkpoints
self._checkpoints = self._checkpoints[:-1]
raise e
finally:
self._finalize_thread = None
# Additional work is being done on process 0 of the finalize threads.
# When joining the threads, we must wait for all threads to complete
# before proceeding.
Expand All @@ -1485,6 +1486,10 @@ def wait_until_finished(self):
processes=self._multiprocessing_options.active_processes,
)

def is_saving_in_progress(self) -> bool:
"""Returns whether a checkpoint save is in progress."""
return self._finalize_thread is not None

def check_for_errors(self):
"""See superclass documentation."""
if is_async_checkpointer(self._checkpointer):
Expand Down Expand Up @@ -1528,7 +1533,7 @@ def _finalize_checkpoint(self, step: int):
)

def _finalize(self, directory: epath.Path, steps_to_remove: List[int]):
"""Cleans up old checkpoints and synchronizes hosts."""
"""Finalizes individual items and starts garbage collection."""
self._non_blocking_checkpoint_metadata_store.wait_until_finished()
self._wait_for_checkpointers()
# If an error is encountered while waiting for commit futures to complete,
Expand Down

0 comments on commit 384a6c9

Please sign in to comment.