From ba16e95af85ae2a8ee9aaabb44f22f0c85221042 Mon Sep 17 00:00:00 2001 From: Piotr Stanczyk Date: Wed, 15 May 2024 23:32:36 -0700 Subject: [PATCH] Add user-configured callback to async_checkpointer. PiperOrigin-RevId: 634228569 --- checkpoint/orbax/checkpoint/async_checkpointer.py | 4 ++++ checkpoint/orbax/checkpoint/checkpoint_manager.py | 2 ++ 2 files changed, 6 insertions(+) diff --git a/checkpoint/orbax/checkpoint/async_checkpointer.py b/checkpoint/orbax/checkpoint/async_checkpointer.py index 94abc3e2..945c7409 100644 --- a/checkpoint/orbax/checkpoint/async_checkpointer.py +++ b/checkpoint/orbax/checkpoint/async_checkpointer.py @@ -252,6 +252,7 @@ def __init__( # is provided. active_processes: Optional[Set[int]] = None, barrier_sync_fn: Optional[BarrierSyncFn] = None, + post_finalization_callback: Optional[Callable[[], None]] = None ): jax.monitoring.record_event('/jax/orbax/async_checkpointer/init') if not checkpoint_args.has_registered_args(handler): @@ -266,6 +267,7 @@ def __init__( self._handler = handler self._primary_host = primary_host self._active_processes = active_processes + self._post_finalization_callback = post_finalization_callback # TODO(dicentra): consider folding into AsyncCheckpointer directly. self._async_manager = _AsyncManager( @@ -325,6 +327,8 @@ def save( # Directory is the final directory. def _callback() -> None: self._handler.finalize(tmpdir) + if self._post_finalization_callback is not None: + self._post_finalization_callback() _on_commit_callback(tmpdir, directory, checkpoint_start_time) self._async_manager.start_async_commit( diff --git a/checkpoint/orbax/checkpoint/checkpoint_manager.py b/checkpoint/orbax/checkpoint/checkpoint_manager.py index 052c46db..597600d4 100644 --- a/checkpoint/orbax/checkpoint/checkpoint_manager.py +++ b/checkpoint/orbax/checkpoint/checkpoint_manager.py @@ -134,6 +134,7 @@ class AsyncOptions: timeout_secs: int = 300 barrier_sync_fn: Optional[async_checkpointer.BarrierSyncFn] = None + post_finalization_callback: Optional[Callable[[], None]] = None @dataclasses.dataclass @@ -618,6 +619,7 @@ def _configure_checkpointer_common( primary_host=self._multiprocessing_options.primary_host, barrier_sync_fn=options.async_options.barrier_sync_fn, active_processes=self._multiprocessing_options.active_processes, + post_finalization_callback=options.async_options.post_finalization_callback, ) else: return async_checkpointer.AsyncCheckpointer(