diff --git a/checkpoint/orbax/checkpoint/path/atomicity.py b/checkpoint/orbax/checkpoint/path/atomicity.py index 1c239d66..2c014330 100644 --- a/checkpoint/orbax/checkpoint/path/atomicity.py +++ b/checkpoint/orbax/checkpoint/path/atomicity.py @@ -100,6 +100,10 @@ def get(self) -> epath.Path: """Constructs the temporary path without actually creating it.""" ... + def get_final(self) -> epath.Path: + """Returns the final path without creating it.""" + ... + def create( self, *, @@ -288,6 +292,9 @@ def match(cls, temporary_path: epath.Path, final_path: epath.Path) -> bool: def get(self) -> epath.Path: return self._tmp_path + def get_final(self) -> epath.Path: + return self._final_path + def create( self, *, @@ -406,6 +413,9 @@ def match(cls, temporary_path: epath.Path, final_path: epath.Path) -> bool: def get(self) -> epath.Path: return self._tmp_path + def get_final(self) -> epath.Path: + return self._final_path + def create( self, *, @@ -482,4 +492,4 @@ def on_commit_callback( """ tmp_dir.finalize() step_lib.record_saved_duration(checkpoint_start_time) - logging.info('Finished saving checkpoint to `%s`.', tmp_dir.get()) + logging.info('Finished saving checkpoint to `%s`.', tmp_dir.get_final())