Skip to content

Commit

Permalink
Fix callsites of handler.async_save to handle returned None.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 657223381
  • Loading branch information
niketkumar authored and Orbax Authors committed Jul 29, 2024
1 parent 529b25c commit 0b6aa48
Show file tree
Hide file tree
Showing 4 changed files with 19 additions and 11 deletions.
4 changes: 4 additions & 0 deletions checkpoint/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## [Unreleased]

### Fixed
- Fix callsites of handler.async_save to handle returned None.


## [0.5.23] - 2024-07-26

### Changed
Expand Down
9 changes: 4 additions & 5 deletions checkpoint/orbax/checkpoint/array_checkpoint_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,8 +106,9 @@ async def async_save():
commit_futures = await self.async_save(directory, *args, **kwargs) # pytype: disable=bad-return-type
# Futures are already running, so sequential waiting is equivalent to
# concurrent waiting.
for f in commit_futures:
f.result() # Block on result.
if commit_futures: # May be None.
for f in commit_futures:
f.result() # Block on result.

asyncio.run(async_save())

Expand Down Expand Up @@ -154,9 +155,7 @@ def restore(
path=checkpoint_path,
parent_dir=directory,
skip_deserialize=False,
is_ocdbt_checkpoint=type_handlers.is_ocdbt_checkpoint(
directory
),
is_ocdbt_checkpoint=type_handlers.is_ocdbt_checkpoint(directory),
)
restore_type = restore_args.restore_type
if restore_type is None:
Expand Down
6 changes: 5 additions & 1 deletion checkpoint/orbax/checkpoint/composite_checkpoint_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -404,7 +404,11 @@ async def async_save(
_maybe_raise_reserved_item_error(item_name)
handler = self._get_or_set_handler(item_name, arg)
if isinstance(handler, AsyncCheckpointHandler):
futures.extend(await handler.async_save(item_directory.get(), args=arg))
commit_futures = await handler.async_save(
item_directory.get(), args=arg
)
if commit_futures is not None:
futures.extend(commit_futures)
else:
handler.save(item_directory.get(), args=arg)
return futures
Expand Down
11 changes: 6 additions & 5 deletions checkpoint/orbax/checkpoint/random_key_checkpoint_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,8 +130,9 @@ async def async_save():
commit_futures = await self.async_save(directory, *args, **kwargs) # pytype: disable=bad-return-type
# Futures are already running, so sequential waiting is equivalent to
# concurrent waiting.
for f in commit_futures:
f.result() # Block on result.
if commit_futures: # May be None.
for f in commit_futures:
f.result() # Block on result.

asyncio.run(async_save())

Expand All @@ -158,9 +159,7 @@ def restore(
}),
)

return self.post_restore(
result[self._key_name], result[self._key_metadata]
)
return self.post_restore(result[self._key_name], result[self._key_metadata])

def finalize(self, directory: epath.Path):
self._handler.finalize(directory)
Expand Down Expand Up @@ -284,11 +283,13 @@ class NumpyRandomKeySaveArgs(CheckpointArgs):
Attributes:
item (required): a Numpy random key in legacy or nonlegacy format
"""

item: NumpyRandomKeyType


@register_with_handler(NumpyRandomKeyCheckpointHandler, for_restore=True)
@dataclasses.dataclass
class NumpyRandomKeyRestoreArgs(CheckpointArgs):
"""Numpy random key restore args."""

pass

0 comments on commit 0b6aa48

Please sign in to comment.