Skip to content

Commit

Permalink
Add user-configured callback to async_checkpointer.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 634228569
  • Loading branch information
qstanczyk authored and Orbax Authors committed May 16, 2024
1 parent 0493c82 commit ba16e95
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 0 deletions.
4 changes: 4 additions & 0 deletions checkpoint/orbax/checkpoint/async_checkpointer.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,6 +252,7 @@ def __init__(
# is provided.
active_processes: Optional[Set[int]] = None,
barrier_sync_fn: Optional[BarrierSyncFn] = None,
post_finalization_callback: Optional[Callable[[], None]] = None
):
jax.monitoring.record_event('/jax/orbax/async_checkpointer/init')
if not checkpoint_args.has_registered_args(handler):
Expand All @@ -266,6 +267,7 @@ def __init__(
self._handler = handler
self._primary_host = primary_host
self._active_processes = active_processes
self._post_finalization_callback = post_finalization_callback

# TODO(dicentra): consider folding into AsyncCheckpointer directly.
self._async_manager = _AsyncManager(
Expand Down Expand Up @@ -325,6 +327,8 @@ def save(
# Directory is the final directory.
def _callback() -> None:
self._handler.finalize(tmpdir)
if self._post_finalization_callback is not None:
self._post_finalization_callback()
_on_commit_callback(tmpdir, directory, checkpoint_start_time)

self._async_manager.start_async_commit(
Expand Down
2 changes: 2 additions & 0 deletions checkpoint/orbax/checkpoint/checkpoint_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,7 @@ class AsyncOptions:

timeout_secs: int = 300
barrier_sync_fn: Optional[async_checkpointer.BarrierSyncFn] = None
post_finalization_callback: Optional[Callable[[], None]] = None


@dataclasses.dataclass
Expand Down Expand Up @@ -618,6 +619,7 @@ def _configure_checkpointer_common(
primary_host=self._multiprocessing_options.primary_host,
barrier_sync_fn=options.async_options.barrier_sync_fn,
active_processes=self._multiprocessing_options.active_processes,
post_finalization_callback=options.async_options.post_finalization_callback,
)
else:
return async_checkpointer.AsyncCheckpointer(
Expand Down

0 comments on commit ba16e95

Please sign in to comment.