Skip to content

Commit

Permalink
DO NOT SUBMIT
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 654697026
  • Loading branch information
Orbax Authors committed Aug 6, 2024
1 parent f84f454 commit 253636c
Show file tree
Hide file tree
Showing 2 changed files with 396 additions and 0 deletions.
167 changes: 167 additions & 0 deletions checkpoint/orbax/checkpoint/handlers/handler_registration.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,167 @@
# Copyright 2024 The Orbax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Global registry for `CheckpointHandler`s."""

from typing import MutableMapping, Optional, Protocol, Type, Union

from orbax.checkpoint import checkpoint_args
from orbax.checkpoint import checkpoint_handler

# Short aliases for the core checkpointing types used throughout this module.
CheckpointArgs = checkpoint_args.CheckpointArgs
CheckpointHandler = checkpoint_handler.CheckpointHandler

# Storage type of a registry: keys are `(item, args_type)` pairs, where `item`
# may be None, and values are `CheckpointHandler`s.
HandlerRegistryMapping = MutableMapping[
    tuple[Optional[str], type[CheckpointArgs]],
    CheckpointHandler,
]


class CheckpointHandlerRegistry(Protocol):
  """Structural interface for a registry of `CheckpointHandler`s.

  Entries are addressed by an `(item, args_type)` pair, where `item` is an
  optional item name and `args_type` is a `CheckpointArgs` subclass.
  """

  def add(
      self,
      item: Optional[str],
      args_type: Type[CheckpointArgs],
      handler: Union[CheckpointHandler, Type[CheckpointHandler]],
      **kwargs,
  ):
    """Registers `handler` (an instance or a class) for `(item, args_type)`."""

  def get(
      self,
      item: Optional[str],
      args_type: Type[CheckpointArgs],
  ) -> CheckpointHandler:
    """Returns the handler registered for `(item, args_type)`."""

  def has(
      self,
      item: Optional[str],
      args_type: Type[CheckpointArgs],
  ) -> bool:
    """Returns whether an entry exists for `(item, args_type)`."""

  def get_all_entries(self) -> HandlerRegistryMapping:
    """Returns a mapping from `(item, args_type)` to the registered handler."""


class AlreadyExistsError(ValueError):
  """Signals an attempt to add an entry that the registry already contains."""


class NoEntryError(KeyError):
  """Signals a lookup for which the registry holds no matching entry."""


class DefaultCheckpointHandlerRegistry(CheckpointHandlerRegistry):
  """Default, dict-backed implementation of `CheckpointHandlerRegistry`.

  Handlers are stored under `(item, args_type)` keys. An entry whose `item` is
  None acts as the general handler for its `args_type`: `get` falls back to it
  when the requested item has no entry of its own.

  A new registry can optionally be seeded with the entries of another registry.
  """

  def __init__(
      self, other_registry: Optional[CheckpointHandlerRegistry] = None
  ):
    """Creates a registry, optionally seeded from `other_registry`.

    Args:
      other_registry: If not None, every entry it reports via
        `get_all_entries()` is added to this registry. Handler instances are
        stored as-is (they are not copied).
    """
    self._registry: HandlerRegistryMapping = {}

    # Compare against None explicitly so that a registry object that happens
    # to evaluate as falsy is still consulted.
    if other_registry is not None:
      entries = other_registry.get_all_entries()
      for (item, args_type), handler in entries.items():
        self.add(item, args_type, handler)

  def add(
      self,
      item: Optional[str],
      args_type: Type[CheckpointArgs],
      handler: Union[CheckpointHandler, Type[CheckpointHandler]],
  ):
    """Adds an entry to the registry.

    Args:
      item: The item name. If None, the entry will be added as a general
        `args_type` entry.
      args_type: The args type.
      handler: The handler. If a type is provided, an instance of the type
        (constructed with no arguments) will be added to the registry.

    Raises:
      AlreadyExistsError: If an entry for the given item and args type already
        exists in the registry.
    """
    if self.has(item, args_type):
      raise AlreadyExistsError(
          f'Entry for item={item} and args_type={args_type} already'
          ' exists in the registry.'
      )

    # The registry always stores instances: classes are instantiated here,
    # while ready-made instances are stored unchanged.
    handler_instance = handler() if isinstance(handler, type) else handler
    self._registry[(item, args_type)] = handler_instance

  def get(
      self,
      item: Optional[str],
      args_type: Type[CheckpointArgs],
  ) -> CheckpointHandler:
    """Returns the handler for the given item and args type.

    If the item has not been registered, the general `args_type` entry (the one
    registered with `item=None`) is returned if it exists.

    Args:
      item: The item name, or None to look up the general `args_type` entry.
      args_type: The args type.

    Returns:
      The registered handler instance.

    Raises:
      NoEntryError: If no entry for the given item and args type exists in the
        registry, and there is no general `args_type` entry to fall back to.
    """
    # Probe the exact key first, then the general `(None, args_type)` key.
    candidates = (item,) if item is None else (item, None)
    for candidate in candidates:
      if self.has(candidate, args_type):
        return self._registry[(candidate, args_type)]

    # NOTE: the message text (including the `args_ty` spelling) is matched by
    # the unit tests, so it is intentionally left unchanged.
    raise NoEntryError(
        f'No entry for item={item} and args_ty={args_type} in the registry.'
    )

  def has(self, item: Optional[str], args_type: type[CheckpointArgs]) -> bool:
    """Returns whether an exact entry exists for the item and args type.

    Does not check for fall back to the general `args_type` entry.

    Args:
      item: The item name or None.
      args_type: The args type.

    Returns:
      True if `(item, args_type)` itself is registered, False otherwise.
    """
    return (item, args_type) in self._registry

  def get_all_entries(
      self,
  ) -> HandlerRegistryMapping:
    """Returns all entries in the registry.

    The returned object is the registry's underlying mapping, not a copy, so
    changes made through it are visible to the registry.
    """
    return self._registry
229 changes: 229 additions & 0 deletions checkpoint/orbax/checkpoint/handlers/handler_registration_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,229 @@
# Copyright 2024 The Orbax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import dataclasses
from typing import Optional, Type, Union

from absl.testing import absltest
from absl.testing import parameterized
from etils import epath
from orbax.checkpoint import checkpoint_args
from orbax.checkpoint import checkpoint_handler
from orbax.checkpoint.handlers import handler_registration


# Local aliases that keep the test bodies concise.
CheckpointHandler = checkpoint_handler.CheckpointHandler
DefaultCheckpointHandlerRegistry = (
    handler_registration.DefaultCheckpointHandlerRegistry
)
AlreadyExistsError = handler_registration.AlreadyExistsError
NoEntryError = handler_registration.NoEntryError


class _TestCheckpointHandler(CheckpointHandler):
  """Checkpoint handler stub whose `save` and `restore` do nothing."""

  def save(self, directory: epath.Path, *args, **kwargs) -> None:
    """Accepts any arguments and discards them without saving anything."""
    del directory, args, kwargs  # Intentionally unused.

  def restore(self, directory: epath.Path, *args, **kwargs) -> None:
    """Accepts any arguments and discards them; always returns None."""
    del directory, args, kwargs  # Intentionally unused.


@dataclasses.dataclass
class _TestArgs(checkpoint_args.CheckpointArgs):
  """Field-less `CheckpointArgs` subclass used as a registry key in tests."""


class HandlerRegistryTest(parameterized.TestCase):
  """Tests for `DefaultCheckpointHandlerRegistry`."""

  @parameterized.product(
      handler=(_TestCheckpointHandler, _TestCheckpointHandler()),
      item=(None, 'item'),
  )
  def test_add_and_get_entry(
      self,
      handler: Union[CheckpointHandler, Type[CheckpointHandler]],
      item: Optional[str],
  ):
    """A handler class or instance can be added and then retrieved."""
    args_type = _TestArgs
    registry = DefaultCheckpointHandlerRegistry()

    registry.add(item, args_type, handler)

    # Check that the entry is added to the registry.
    self.assertTrue(registry.has(item, args_type))
    # Check that the handler is returned and that it is initialized as an
    # object.
    self.assertIsInstance(
        registry.get(item, args_type),
        _TestCheckpointHandler,
    )

  def test_add_entry_with_existing_item_and_args_type_raises_error(self):
    """Adding the same `(item, args_type)` twice raises AlreadyExistsError."""
    item = 'item'
    args_type = _TestArgs
    registry = DefaultCheckpointHandlerRegistry()

    registry.add(item, args_type, _TestCheckpointHandler)

    with self.assertRaisesRegex(
        AlreadyExistsError, r'already exists in the registry'
    ):
      registry.add(item, args_type, _TestCheckpointHandler)

  def test_get_all_entries(self):
    """`get_all_entries` exposes every registered `(item, args_type)` key."""
    item1 = 'item1'
    item2 = 'item2'
    args_type = _TestArgs
    handler = _TestCheckpointHandler
    registry = DefaultCheckpointHandlerRegistry()

    registry.add(item1, args_type, handler)
    registry.add(item2, args_type, handler)

    entries = registry.get_all_entries()
    self.assertLen(entries, 2)
    self.assertIsInstance(entries[(item1, args_type)], handler)
    self.assertIsInstance(entries[(item2, args_type)], handler)

  def test_instantiate_registry_from_another_registry(self):
    """A registry seeded from another one copies its entries independently."""
    item1 = 'item1'
    item2 = 'item2'
    args_type = _TestArgs
    handler = _TestCheckpointHandler

    registry1 = DefaultCheckpointHandlerRegistry()
    registry1.add(item1, args_type, handler)
    registry2 = DefaultCheckpointHandlerRegistry(registry1)
    registry2.add(item2, args_type, handler)

    entries = registry2.get_all_entries()
    self.assertLen(entries, 2)
    self.assertIsInstance(entries[(item1, args_type)], handler)
    self.assertIsInstance(entries[(item2, args_type)], handler)
    # Adding to the derived registry must not leak back into the source one.
    self.assertLen(registry1.get_all_entries(), 1)
    self.assertFalse(registry1.has(item2, args_type))

  @parameterized.product(
      item=(None, 'item'),
  )
  def test_raise_error_when_no_entry_found(self, item: Optional[str]):
    """Looking up an unregistered key raises NoEntryError."""
    registry = DefaultCheckpointHandlerRegistry()

    with self.assertRaisesRegex(
        NoEntryError,
        r'No entry for item=.* and args_ty=.* in the registry',
    ):
      registry.get(item, _TestArgs)

  def test_concrete_item_takes_precedence_over_general_args_type(self):
    """An item-specific entry wins over the general `item=None` entry."""
    none_item = None
    item = 'item'
    args_type = _TestArgs

    class _TestCheckpointHandlerA(_TestCheckpointHandler):
      pass

    class _TestCheckpointHandlerB(_TestCheckpointHandler):
      pass

    registry = DefaultCheckpointHandlerRegistry()
    registry.add(none_item, args_type, _TestCheckpointHandlerA)
    registry.add(item, args_type, _TestCheckpointHandlerB)

    self.assertTrue(registry.has(none_item, args_type))
    self.assertTrue(registry.has(item, args_type))
    self.assertIsInstance(
        registry.get(none_item, args_type),
        _TestCheckpointHandlerA,
    )
    self.assertIsInstance(
        registry.get(item, args_type),
        _TestCheckpointHandlerB,
    )

  def test_falls_back_to_general_args_type(self):
    """An unregistered item resolves to the general `item=None` entry."""
    none_item = None
    registered_item = 'registered_item'
    item_without_registration = 'item_without_registration'
    args_type = _TestArgs

    class _TestCheckpointHandlerA(_TestCheckpointHandler):
      pass

    class _TestCheckpointHandlerB(_TestCheckpointHandler):
      pass

    registry = DefaultCheckpointHandlerRegistry()
    registry.add(none_item, args_type, _TestCheckpointHandlerA)
    registry.add(registered_item, args_type, _TestCheckpointHandlerB)

    self.assertTrue(registry.has(none_item, args_type))
    self.assertTrue(registry.has(registered_item, args_type))
    self.assertFalse(registry.has(item_without_registration, args_type))

    self.assertIsInstance(
        registry.get(none_item, args_type),
        _TestCheckpointHandlerA,
    )
    self.assertIsInstance(
        registry.get(item_without_registration, args_type),
        _TestCheckpointHandlerA,
    )
    self.assertIsInstance(
        registry.get(registered_item, args_type),
        _TestCheckpointHandlerB,
    )

  def test_multiple_handlers_for_same_item(self):
    """One item may map to different handlers for different args types."""
    item = 'item'

    class _TestArgsA(checkpoint_args.CheckpointArgs):
      pass

    class _TestArgsB(checkpoint_args.CheckpointArgs):
      pass

    # Distinct (sibling) handler classes, so that the assertions below fail if
    # the lookup were to ignore `args_type`.
    class _TestCheckpointHandlerA(_TestCheckpointHandler):
      pass

    class _TestCheckpointHandlerB(_TestCheckpointHandler):
      pass

    registry = DefaultCheckpointHandlerRegistry()
    registry.add(item, _TestArgsA, _TestCheckpointHandlerA)
    registry.add(item, _TestArgsB, _TestCheckpointHandlerB)

    self.assertIsInstance(
        registry.get(item, _TestArgsA),
        _TestCheckpointHandlerA,
    )
    self.assertIsInstance(
        registry.get(item, _TestArgsB),
        _TestCheckpointHandlerB,
    )

# Run the test suite through absltest when this file is executed directly.
if __name__ == '__main__':
  absltest.main()

0 comments on commit 253636c

Please sign in to comment.