Skip to content

Commit

Permalink
Pass FileOptions instead of path_permission_mode to checkpointers.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 671743702
  • Loading branch information
Orbax Authors committed Sep 6, 2024
1 parent 96f28e6 commit 79e2b01
Show file tree
Hide file tree
Showing 7 changed files with 16 additions and 17 deletions.
4 changes: 2 additions & 2 deletions checkpoint/orbax/checkpoint/async_checkpointer.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,7 +265,7 @@ def __init__(
*,
async_options: options_lib.AsyncOptions = options_lib.AsyncOptions(),
multiprocessing_options: options_lib.MultiprocessingOptions = options_lib.MultiprocessingOptions(),
path_permission_mode: Optional[int] = None,
file_options: options_lib.FileOptions = options_lib.FileOptions(),
checkpoint_metadata_store: Optional[
checkpoint.CheckpointMetadataStore
] = None,
Expand Down Expand Up @@ -294,7 +294,7 @@ def __init__(
else f'{multiprocessing_options.barrier_sync_key_prefix}.{unique_class_id}'
)
self._barrier_sync_key_prefix = barrier_sync_key_prefix
self._path_permission_mode = path_permission_mode # e.g. 0o750
self._file_options = file_options
self._checkpoint_metadata_store = (
checkpoint_metadata_store
or checkpoint.checkpoint_metadata_store(enable_write=True)
Expand Down
8 changes: 5 additions & 3 deletions checkpoint/orbax/checkpoint/checkpoint_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -551,6 +551,8 @@ def __init__(
'item_names'. See :py:class:`CheckpointHandlerRegistry` for more
details.
"""
logging.info('DEBUG: __init__ directory %s', directory)

jax.monitoring.record_event('/jax/orbax/checkpoint_manager/init')
logging.info(
'[process=%s][thread=%s] CheckpointManager init: checkpointers=%s,'
Expand Down Expand Up @@ -691,7 +693,7 @@ def __init__(
multiprocessing_options=self._multiprocessing_options
),
multiprocessing_options=self._options.multiprocessing_options,
path_permission_mode=self._options.file_options.path_permission_mode,
file_options=self._options.file_options,
checkpoint_metadata_store=self._blocking_checkpoint_metadata_store,
temporary_path_class=self._options.temporary_path_class,
)
Expand Down Expand Up @@ -761,15 +763,15 @@ def _configure_checkpointer_common(
handler,
multiprocessing_options=options.multiprocessing_options,
async_options=options.async_options or AsyncOptions(),
path_permission_mode=options.file_options.path_permission_mode,
file_options=options.file_options,
checkpoint_metadata_store=self._non_blocking_checkpoint_metadata_store,
temporary_path_class=options.temporary_path_class,
)
else:
return Checkpointer(
handler,
multiprocessing_options=options.multiprocessing_options,
path_permission_mode=options.file_options.path_permission_mode,
file_options=options.file_options,
checkpoint_metadata_store=self._blocking_checkpoint_metadata_store,
temporary_path_class=options.temporary_path_class,
)
Expand Down
8 changes: 3 additions & 5 deletions checkpoint/orbax/checkpoint/checkpointer.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ def __init__(
handler: checkpoint_handler.CheckpointHandler,
*,
multiprocessing_options: options_lib.MultiprocessingOptions = options_lib.MultiprocessingOptions(),
path_permission_mode: Optional[int] = None,
file_options: options_lib.FileOptions = options_lib.FileOptions(),
checkpoint_metadata_store: Optional[
checkpoint.CheckpointMetadataStore
] = None,
Expand All @@ -122,7 +122,7 @@ def __init__(
self._barrier_sync_key_prefix = (
multiprocessing_options.barrier_sync_key_prefix
)
self._path_permission_mode = path_permission_mode # e.g. 0o750
self._file_options = file_options
self._temporary_path_class = temporary_path_class

# If not provided then use checkpoint_metadata_store with blocking writes.
Expand Down Expand Up @@ -153,9 +153,7 @@ async def create_temporary_path(
directory,
checkpoint_metadata_store=self._checkpoint_metadata_store,
multiprocessing_options=multiprocessing_options,
file_options=options_lib.FileOptions(
path_permission_mode=self._path_permission_mode
),
file_options=self._file_options,
)
await atomicity.create_all(
[tmpdir], multiprocessing_options=multiprocessing_options
Expand Down
2 changes: 2 additions & 0 deletions checkpoint/orbax/checkpoint/options.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,3 +67,5 @@ class FileOptions:
"""

path_permission_mode: Optional[int] = None

data_governance_annotations: Optional[annotations_pb2.Annotations] = None
3 changes: 0 additions & 3 deletions checkpoint/orbax/checkpoint/path/atomicity.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,9 +60,6 @@
from orbax.checkpoint import metadata
from orbax.checkpoint import multihost
from orbax.checkpoint import options as options_lib
from orbax.checkpoint.path import async_utils
from orbax.checkpoint.path import step as step_lib


TMP_DIR_SUFFIX = step_lib.TMP_DIR_SUFFIX

Expand Down
2 changes: 1 addition & 1 deletion checkpoint/orbax/checkpoint/path/atomicity_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@

"""Tests for atomicity.py."""

import unittest
from absl.testing import absltest
from absl.testing import parameterized
from etils import epath
Expand Down Expand Up @@ -151,5 +150,6 @@ async def test_create_all(self):
self.assertTrue(paths[1].exists())



if __name__ == '__main__':
absltest.main()
6 changes: 3 additions & 3 deletions checkpoint/orbax/checkpoint/standard_checkpointer.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def __init__(
*,
async_options: options_lib.AsyncOptions = options_lib.AsyncOptions(),
multiprocessing_options: options_lib.MultiprocessingOptions = options_lib.MultiprocessingOptions(),
path_permission_mode: Optional[int] = None,
file_options: options_lib.FileOptions = options_lib.FileOptions(),
checkpoint_metadata_store: Optional[
checkpoint.CheckpointMetadataStore
] = None,
Expand All @@ -73,7 +73,7 @@ def __init__(
Args:
async_options: See superclass documentation.
multiprocessing_options: See superclass documentation.
path_permission_mode: See superclass documentation.
file_options: See superclass documentation.
checkpoint_metadata_store: See superclass documentation.
temporary_path_class: See superclass documentation.
**kwargs: Additional init args passed to StandardCheckpointHandler. See
Expand All @@ -86,7 +86,7 @@ def __init__(
),
async_options=async_options,
multiprocessing_options=multiprocessing_options,
path_permission_mode=path_permission_mode,
file_options=file_options,
checkpoint_metadata_store=checkpoint_metadata_store,
temporary_path_class=temporary_path_class,
)
Expand Down

0 comments on commit 79e2b01

Please sign in to comment.