Skip to content

Commit 19d2495

Browse files
Add OrbaxCheckpoint callback with conditional exports and improved test handling
- Implement OrbaxCheckpoint callback for async checkpointing with state tree handling - Add conditional exports for optional orbax-checkpoint dependency - Use pytest.importorskip for clean optional dependency testing - Ensure graceful handling when orbax-checkpoint is not installed
1 parent 61bd5e6 commit 19d2495

File tree

2 files changed

+30
-22
lines changed

2 files changed

+30
-22
lines changed

keras/src/callbacks/orbax_checkpoint.py

Lines changed: 14 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -21,21 +21,9 @@
2121
),
2222
)
2323

24-
# Expose advanced Orbax functionality for users who need direct access
25-
# These are provided as bridge for advanced usecases like custom type handlers
26-
CheckpointManager = ocp.CheckpointManager
27-
SaveArgs = ocp.SaveArgs
28-
StandardRestore = ocp.args.StandardRestore
29-
30-
# Type handler functionality for custom serialization
31-
TypeHandler = ocp.type_handlers.TypeHandler
32-
register_type_handler = ocp.type_handlers.register_type_handler
33-
34-
# Direct checkpointing for custom objects
35-
PyTreeCheckpointer = ocp.PyTreeCheckpointer
36-
37-
# Metadata functionality
38-
metadata = ocp.metadata
24+
# Note: Advanced Orbax functionality is available through the ocp LazyModule
25+
# Users can access it via: from keras.src.utils.module_utils import LazyModule
26+
# ocp = LazyModule("orbax.checkpoint"); ocp.CheckpointManager
3927

4028

4129
def _get_state_tree(model):
@@ -701,3 +689,14 @@ def _restore_from_state_tree(self, state_tree, target_model):
701689
if self.verbose > 0:
702690
print_msg("OrbaxCheckpoint: Successfully restored model state")
703691
return True
692+
693+
694+
# Export additional Orbax functionality for advanced users (only if available)
695+
if ocp.available:
696+
CheckpointManager = ocp.CheckpointManager
697+
PyTreeCheckpointer = ocp.PyTreeCheckpointer
698+
SaveArgs = ocp.SaveArgs
699+
StandardRestore = ocp.args.StandardRestore
700+
TypeHandler = ocp.type_handlers.TypeHandler
701+
metadata = ocp.metadata
702+
register_type_handler = ocp.type_handlers.register_type_handler

keras/src/callbacks/orbax_checkpoint_test.py

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -11,14 +11,23 @@
1111
from keras.src import testing
1212

1313
# Import advanced Orbax functionality through the Keras bridge
14-
from keras.src.callbacks.orbax_checkpoint import CheckpointManager
14+
# These will only be available if orbax-checkpoint is installed
15+
try:
16+
from keras.src.callbacks.orbax_checkpoint import CheckpointManager
17+
from keras.src.callbacks.orbax_checkpoint import PyTreeCheckpointer
18+
from keras.src.callbacks.orbax_checkpoint import SaveArgs
19+
from keras.src.callbacks.orbax_checkpoint import StandardRestore
20+
from keras.src.callbacks.orbax_checkpoint import TypeHandler
21+
from keras.src.callbacks.orbax_checkpoint import metadata
22+
from keras.src.callbacks.orbax_checkpoint import register_type_handler
23+
except ImportError:
24+
# If orbax is not available, these won't be exported
25+
pass
26+
1527
from keras.src.callbacks.orbax_checkpoint import OrbaxCheckpoint
16-
from keras.src.callbacks.orbax_checkpoint import PyTreeCheckpointer
17-
from keras.src.callbacks.orbax_checkpoint import SaveArgs
18-
from keras.src.callbacks.orbax_checkpoint import StandardRestore
19-
from keras.src.callbacks.orbax_checkpoint import TypeHandler
20-
from keras.src.callbacks.orbax_checkpoint import metadata
21-
from keras.src.callbacks.orbax_checkpoint import register_type_handler
28+
29+
# Skip the entire test module if orbax-checkpoint is not available
30+
pytest.importorskip("orbax.checkpoint")
2231

2332

2433
class OrbaxCheckpointTest(testing.TestCase):

0 commit comments

Comments
 (0)