55
66from keras .src import backend
77from keras .src import ops
8+ from keras .src import tree
89from keras .src .api_export import keras_export
910from keras .src .callbacks .monitor_callback import (
1011 MonitorCallback , # For metric monitoring logic
1112)
13+ from keras .src .distribution .distribution_lib import process_id
1214from keras .src .utils .io_utils import print_msg
13- from keras .src .utils .module_utils import LazyModule
14-
15- ocp = LazyModule (
16- "orbax.checkpoint" ,
17- pip_name = "orbax-checkpoint" ,
18- import_error_msg = (
19- "OrbaxCheckpoint requires the 'orbax-checkpoint' package. "
20- "Install it with: pip install orbax-checkpoint"
21- ),
22- )
23-
24- # Note: Advanced Orbax functionality is available through the ocp LazyModule
25- # Users can access it via: from keras.src.utils.module_utils import LazyModule
26- # ocp = LazyModule("orbax.checkpoint"); ocp.CheckpointManager
15+ from keras .src .utils .module_utils import ocp
2716
2817
2918def _get_state_tree (model ):
@@ -38,68 +27,49 @@ def convert_scalars(obj):
3827 elif isinstance (obj , np .generic ):
3928 # Convert numpy scalar types (like np.float32) to Python types
4029 return obj .item ()
41- elif isinstance (obj , dict ):
42- return {k : convert_scalars (v ) for k , v in obj .items ()}
4330 else :
4431 return obj
4532
46- return convert_scalars ( state_tree )
33+ return tree . map_structure ( convert_scalars , state_tree )
4734
4835
def _flatten_state_tree_values(state_tree):
    """Flatten a nested state tree into a flat list of leaf values.

    Delegates to `tree.flatten`, which yields leaves in a consistent,
    deterministic order, so repeated calls on equal structures produce
    values in the same sequence.
    """
    leaves = tree.flatten(state_tree)
    return leaves
6339
6440
def _reconstruct_state_tree_with_values(structure, values):
    """Rebuild `structure` leaf-by-leaf from the flat `values` sequence.

    Each leaf of `structure` acts as a template: the corresponding entry
    from `values` is coerced to the template's numpy dtype/shape where
    the template is a numpy scalar or array; other leaves pass the new
    value through unchanged.
    """
    value_iter = iter(values)

    def _coerce_to_template(template, replacement):
        # Match `replacement` to the dtype (and shape) of `template`.
        if isinstance(template, np.generic):
            # Template is a 0-d numpy scalar.
            if isinstance(replacement, np.ndarray):
                single = replacement.ndim == 0 or (
                    replacement.ndim == 1 and replacement.size == 1
                )
                if single:
                    # Collapse a 0-d / single-element array to a scalar.
                    return np.array(
                        replacement.item(), dtype=template.dtype
                    )
                return replacement.astype(template.dtype).reshape(
                    template.shape
                )
            # Python scalars and any other value: let numpy convert.
            return np.array(replacement, dtype=template.dtype)
        if isinstance(template, np.ndarray):
            if isinstance(replacement, np.ndarray):
                return replacement.astype(template.dtype).reshape(
                    template.shape
                )
            return np.array(replacement, dtype=template.dtype).reshape(
                template.shape
            )
        # Non-numpy leaf: take the new value as-is.
        return replacement

    def _next_leaf(template):
        # map_structure visits leaves in the same deterministic order
        # used when the values were flattened, so consume in order.
        return _coerce_to_template(template, next(value_iter))

    return tree.map_structure(_next_leaf, structure)
10373
10474
10575def _restore_legacy_format (
@@ -327,7 +297,7 @@ def __init__(
327297 save_decision_policy = save_decision_policy ,
328298 )
329299 # Ensure directory exists (only needed on one process in multi-host)
330- if backend . get_process_index () == 0 :
300+ if process_id () == 0 :
331301 os .makedirs (directory , exist_ok = True )
332302
333303 # Create the CheckpointManager
@@ -380,38 +350,27 @@ def _save_checkpoint(self, step, logs=None):
380350 state_tree = _get_state_tree (self .model )
381351
382352 if state_tree is None :
383- if self .verbose > 0 :
384- print_msg (
385- "OrbaxCheckpoint: Skipping save due to state tree error"
386- )
387- return
388-
389- # Flatten the trainable variables values for cross-model compatibility
390- trainable_values = _flatten_state_tree_values (
391- state_tree ["trainable_variables" ]
392- )
393-
394- # Save optimizer and metrics state if requested
395- optimizer_values = None
396- if self .save_optimizer_state and "optimizer_variables" in state_tree :
397- optimizer_values = _flatten_state_tree_values (
398- state_tree ["optimizer_variables" ]
399- )
400-
401- metrics_values = None
402- if self .save_metrics_state and "metrics_variables" in state_tree :
403- metrics_values = _flatten_state_tree_values (
404- state_tree ["metrics_variables" ]
353+ raise RuntimeError (
354+ "OrbaxCheckpoint: Failed to get model state tree. "
355+ "The model may not be properly built or may have no "
356+ "savable state."
405357 )
406358
359+ # Save the nested state structures directly (preserving layer
360+ # names and structure)
407361 composite_state = {
408- "model_weights " : trainable_values ,
362+ "trainable_variables " : state_tree [ "trainable_variables" ] ,
409363 }
410364
411- if optimizer_values is not None :
412- composite_state ["optimizer_state" ] = optimizer_values
413- if metrics_values is not None :
414- composite_state ["metrics_variables" ] = metrics_values
365+ if self .save_optimizer_state and "optimizer_variables" in state_tree :
366+ composite_state ["optimizer_variables" ] = state_tree [
367+ "optimizer_variables"
368+ ]
369+
370+ if self .save_metrics_state and "metrics_variables" in state_tree :
371+ composite_state ["metrics_variables" ] = state_tree [
372+ "metrics_variables"
373+ ]
415374
416375 # Add metadata if specified
417376 if self .save_metadata is not None :
@@ -435,7 +394,7 @@ def _save_checkpoint(self, step, logs=None):
435394
436395 # --- Save Logic ---
437396 # Only save on the primary process (rank 0) in distributed setups
438- is_primary_host = backend . get_process_index () == 0
397+ is_primary_host = process_id () == 0
439398
440399 if is_primary_host :
441400 if self .verbose > 0 :
@@ -540,7 +499,7 @@ def load_checkpoint(self, step, model=None):
540499 data iterator state dict if available, None otherwise.
541500 """
542501 # In distributed training, only load on primary process
543- if backend . get_process_index () != 0 :
502+ if process_id () != 0 :
544503 return True # Return True to indicate no error, but no loading
545504
546505 if self .verbose > 0 :
@@ -594,11 +553,18 @@ def _restore_model_state(self, checkpoint_data, model=None):
594553 """
595554 target_model = model if model is not None else self .model
596555
597- # Check if this is the new flattened format
598- if "model_weights" in checkpoint_data and isinstance (
556+ # Check if this is the new nested structure format
557+ if "trainable_variables" in checkpoint_data and isinstance (
558+ checkpoint_data ["trainable_variables" ], dict
559+ ):
560+ # New format: nested structures
561+ return self ._restore_from_nested_structures (
562+ checkpoint_data , target_model
563+ )
564+ elif "model_weights" in checkpoint_data and isinstance (
599565 checkpoint_data ["model_weights" ], list
600566 ):
601- # New format: flattened values
567+ # Old format: flattened values (for backward compatibility)
602568 return self ._restore_from_flattened_values (
603569 checkpoint_data , target_model
604570 )
@@ -617,8 +583,109 @@ def _restore_model_state(self, checkpoint_data, model=None):
617583 )
618584 return True
619585
586+ def _restore_from_nested_structures (self , checkpoint_data , target_model ):
587+ """Restore from the new nested structures format."""
588+ # Ensure the target model is built so it has variables
589+ if len (target_model .trainable_variables ) == 0 :
590+ try :
591+ # Try to build the model by doing a dummy forward pass
592+ if (
593+ hasattr (target_model , "input_shape" )
594+ and target_model .input_shape is not None
595+ ):
596+ dummy_input_shape = target_model .input_shape
597+ if dummy_input_shape [0 ] is None : # Batch dimension is None
598+ dummy_input = np .zeros ((1 ,) + dummy_input_shape [1 :])
599+ else :
600+ dummy_input = np .zeros (dummy_input_shape )
601+ target_model (dummy_input )
602+ except Exception :
603+ # If dummy forward pass fails, try build
604+ try :
605+ if (
606+ hasattr (target_model , "input_shape" )
607+ and target_model .input_shape is not None
608+ ):
609+ build_shape = target_model .input_shape
610+ if (
611+ isinstance (build_shape , (list , tuple ))
612+ and len (build_shape ) > 1
613+ and build_shape [0 ] is None
614+ ):
615+ build_shape = build_shape [1 :]
616+ target_model .build (build_shape )
617+ except Exception :
618+ # If building fails, continue anyway
619+ pass
620+
621+ # Prepare the state tree to restore
622+ reconstructed_state = {}
623+
624+ # Restore trainable variables
625+ if "trainable_variables" in checkpoint_data :
626+ reconstructed_state ["trainable_variables" ] = checkpoint_data [
627+ "trainable_variables"
628+ ]
629+
630+ # Restore optimizer variables if available and model has optimizer
631+ if (
632+ "optimizer_variables" in checkpoint_data
633+ and self .save_optimizer_state
634+ and hasattr (target_model , "optimizer" )
635+ and target_model .optimizer is not None
636+ ):
637+ reconstructed_state ["optimizer_variables" ] = checkpoint_data [
638+ "optimizer_variables"
639+ ]
640+
641+ # Restore metrics variables if available
642+ if "metrics_variables" in checkpoint_data and self .save_metrics_state :
643+ reconstructed_state ["metrics_variables" ] = checkpoint_data [
644+ "metrics_variables"
645+ ]
646+
647+ # Use set_state_tree to restore the state
648+ target_model .set_state_tree (reconstructed_state )
649+
650+ if self .verbose > 0 :
651+ print_msg ("OrbaxCheckpoint: Successfully restored model state" )
652+ return True
653+
620654 def _restore_from_flattened_values (self , checkpoint_data , target_model ):
621655 """Restore from the new flattened values format."""
656+ # Ensure the target model is built so it has variables
657+ if len (target_model .trainable_variables ) == 0 :
658+ try :
659+ # Try to build the model by doing a dummy forward pass
660+ if (
661+ hasattr (target_model , "input_shape" )
662+ and target_model .input_shape is not None
663+ ):
664+ dummy_input_shape = target_model .input_shape
665+ if dummy_input_shape [0 ] is None : # Batch dimension is None
666+ dummy_input = np .zeros ((1 ,) + dummy_input_shape [1 :])
667+ else :
668+ dummy_input = np .zeros (dummy_input_shape )
669+ target_model (dummy_input )
670+ except Exception :
671+ # If dummy forward pass fails, try build
672+ try :
673+ if (
674+ hasattr (target_model , "input_shape" )
675+ and target_model .input_shape is not None
676+ ):
677+ build_shape = target_model .input_shape
678+ if (
679+ isinstance (build_shape , (list , tuple ))
680+ and len (build_shape ) > 1
681+ and build_shape [0 ] is None
682+ ):
683+ build_shape = build_shape [1 :]
684+ target_model .build (build_shape )
685+ except Exception :
686+ # If building fails, continue anyway
687+ pass
688+
622689 # Get the target model's state tree structure (without convert_scalars)
623690 target_state_tree = target_model .get_state_tree (
624691 value_format = "numpy_array"
0 commit comments