Skip to content

Commit

Permalink
Add option to disable sharding file write at TypeHandler level.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 657297029
  • Loading branch information
liangyaning33 authored and Orbax Authors committed Jul 29, 2024
1 parent c6de78a commit 5b0993a
Showing 1 changed file with 5 additions and 1 deletion.
6 changes: 5 additions & 1 deletion checkpoint/orbax/checkpoint/type_handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -1170,6 +1170,7 @@ def __init__(
metadata_key: Optional[str] = None,
primary_host: Optional[int] = 0,
replica_id: Optional[int] = 0,
enable_write_sharding_file: bool = True,
):
"""Constructor.
Expand All @@ -1181,10 +1182,13 @@ def __init__(
replica_id: the replica id to be used for saving. Default to 0. If it's
set to None, each shards will pick first replica_id to be used. It's
useful in the case that all hosts are only working with local storage.
enable_write_sharding_file: whether to write sharding file, defaults to
True.
"""
self._metadata_key = metadata_key
self._primary_host = primary_host
self._replica_id = replica_id
self._enable_write_sharding_file = enable_write_sharding_file

logging.info(
'Created `ArrayHandler` with primary_host=%s, replica_id=%s',
Expand Down Expand Up @@ -1361,7 +1365,7 @@ async def serialize(
)
]

if value.sharding is not None:
if self._enable_write_sharding_file and value.sharding is not None:
if info.parent_dir is None:
raise ValueError('parent_dir cannot be None')
tspec_sharding = get_sharding_tensorstore_spec(
Expand Down

0 comments on commit 5b0993a

Please sign in to comment.