|
21 | 21 | ), |
22 | 22 | ) |
23 | 23 |
|
24 | | -# Expose advanced Orbax functionality for users who need direct access |
25 | | -# These are provided as bridge for advanced usecases like custom type handlers |
26 | | -CheckpointManager = ocp.CheckpointManager |
27 | | -SaveArgs = ocp.SaveArgs |
28 | | -StandardRestore = ocp.args.StandardRestore |
29 | | - |
30 | | -# Type handler functionality for custom serialization |
31 | | -TypeHandler = ocp.type_handlers.TypeHandler |
32 | | -register_type_handler = ocp.type_handlers.register_type_handler |
33 | | - |
34 | | -# Direct checkpointing for custom objects |
35 | | -PyTreeCheckpointer = ocp.PyTreeCheckpointer |
36 | | - |
37 | | -# Metadata functionality |
38 | | -metadata = ocp.metadata |
| 24 | +# Note: Advanced Orbax functionality is available through the ocp LazyModule |
| 25 | +# Users can access it via: from keras.src.utils.module_utils import LazyModule |
| 26 | +# ocp = LazyModule("orbax.checkpoint"); ocp.CheckpointManager |
39 | 27 |
|
40 | 28 |
|
41 | 29 | def _get_state_tree(model): |
@@ -701,3 +689,14 @@ def _restore_from_state_tree(self, state_tree, target_model): |
701 | 689 | if self.verbose > 0: |
702 | 690 | print_msg("OrbaxCheckpoint: Successfully restored model state") |
703 | 691 | return True |
| 692 | + |
| 693 | + |
| 694 | +# Export additional Orbax functionality for advanced users (only if available) |
| 695 | +if ocp.available: |
| 696 | + CheckpointManager = ocp.CheckpointManager |
| 697 | + PyTreeCheckpointer = ocp.PyTreeCheckpointer |
| 698 | + SaveArgs = ocp.SaveArgs |
| 699 | + StandardRestore = ocp.args.StandardRestore |
| 700 | + TypeHandler = ocp.type_handlers.TypeHandler |
| 701 | + metadata = ocp.metadata |
| 702 | + register_type_handler = ocp.type_handlers.register_type_handler |
0 commit comments