diff --git a/checkpoint/CHANGELOG.md b/checkpoint/CHANGELOG.md index f165280d..5e347077 100644 --- a/checkpoint/CHANGELOG.md +++ b/checkpoint/CHANGELOG.md @@ -7,6 +7,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] +### Fixed +- Fix callsites of handler.async_save to handle returned None. + + ## [0.5.23] - 2024-07-26 ### Changed diff --git a/checkpoint/orbax/checkpoint/array_checkpoint_handler.py b/checkpoint/orbax/checkpoint/array_checkpoint_handler.py index ad53cd4b..039f8d03 100644 --- a/checkpoint/orbax/checkpoint/array_checkpoint_handler.py +++ b/checkpoint/orbax/checkpoint/array_checkpoint_handler.py @@ -106,8 +106,9 @@ async def async_save(): commit_futures = await self.async_save(directory, *args, **kwargs) # pytype: disable=bad-return-type # Futures are already running, so sequential waiting is equivalent to # concurrent waiting. - for f in commit_futures: - f.result() # Block on result. + if commit_futures: # May be None. + for f in commit_futures: + f.result() # Block on result. asyncio.run(async_save()) @@ -154,9 +155,7 @@ def restore( path=checkpoint_path, parent_dir=directory, skip_deserialize=False, - is_ocdbt_checkpoint=type_handlers.is_ocdbt_checkpoint( - directory - ), + is_ocdbt_checkpoint=type_handlers.is_ocdbt_checkpoint(directory), ) restore_type = restore_args.restore_type if restore_type is None: diff --git a/checkpoint/orbax/checkpoint/composite_checkpoint_handler.py b/checkpoint/orbax/checkpoint/composite_checkpoint_handler.py index 73d01ca3..917eebb3 100644 --- a/checkpoint/orbax/checkpoint/composite_checkpoint_handler.py +++ b/checkpoint/orbax/checkpoint/composite_checkpoint_handler.py @@ -404,7 +404,11 @@ async def async_save( _maybe_raise_reserved_item_error(item_name) handler = self._get_or_set_handler(item_name, arg) if isinstance(handler, AsyncCheckpointHandler): - futures.extend(await handler.async_save(item_directory.get(), args=arg)) + commit_futures = await handler.async_save( + item_directory.get(), args=arg + ) + if commit_futures is not None: + futures.extend(commit_futures) else: handler.save(item_directory.get(), args=arg) return futures diff --git a/checkpoint/orbax/checkpoint/random_key_checkpoint_handler.py b/checkpoint/orbax/checkpoint/random_key_checkpoint_handler.py index a643d586..726ade05 100644 --- a/checkpoint/orbax/checkpoint/random_key_checkpoint_handler.py +++ b/checkpoint/orbax/checkpoint/random_key_checkpoint_handler.py @@ -130,8 +130,9 @@ async def async_save(): commit_futures = await self.async_save(directory, *args, **kwargs) # pytype: disable=bad-return-type # Futures are already running, so sequential waiting is equivalent to # concurrent waiting. - for f in commit_futures: - f.result() # Block on result. + if commit_futures: # May be None. + for f in commit_futures: + f.result() # Block on result. asyncio.run(async_save()) @@ -158,9 +159,7 @@ def restore( }), ) - return self.post_restore( - result[self._key_name], result[self._key_metadata] - ) + return self.post_restore(result[self._key_name], result[self._key_metadata]) def finalize(self, directory: epath.Path): self._handler.finalize(directory) @@ -284,6 +283,7 @@ class NumpyRandomKeySaveArgs(CheckpointArgs): Attributes: item (required): a Numpy random key in legacy or nonlegacy format """ + item: NumpyRandomKeyType @@ -291,4 +291,5 @@ class NumpyRandomKeySaveArgs(CheckpointArgs): @dataclasses.dataclass class NumpyRandomKeyRestoreArgs(CheckpointArgs): """Numpy random key restore args.""" + pass