4141def _get_state_tree (model ):
4242 """Get the complete model state as a nested tree structure."""
4343 state_tree = model .get_state_tree (value_format = "numpy_array" )
44-
44+
4545 # Convert numpy scalar types to Python types for Orbax compatibility
4646 def convert_scalars (obj ):
4747 if isinstance (obj , np .ndarray ) and obj .ndim == 0 :
@@ -54,29 +54,30 @@ def convert_scalars(obj):
5454 return {k : convert_scalars (v ) for k , v in obj .items ()}
5555 else :
5656 return obj
57-
57+
5858 return convert_scalars (state_tree )
5959
6060
6161def _flatten_state_tree_values (state_tree ):
6262 """Flatten nested state tree into a list of values in consistent order."""
6363 values = []
64+
6465 def _flatten (obj ):
6566 if isinstance (obj , dict ):
6667 for key in sorted (obj .keys ()): # Sort for consistent ordering
6768 _flatten (obj [key ])
6869 else :
6970 # Save any non-dict value (numpy arrays, lists, scalars, etc.)
7071 values .append (obj )
72+
7173 _flatten (state_tree )
7274 return values
7375
7476
7577def _reconstruct_state_tree_with_values (structure , values ):
7678 """Reconstruct state tree structure with provided values."""
77- result = {}
7879 value_iter = iter (values )
79-
80+
8081 def _reconstruct (obj ):
8182 if isinstance (obj , dict ):
8283 new_dict = {}
@@ -109,7 +110,7 @@ def _reconstruct(obj):
109110 return np .array (value , dtype = obj .dtype ).reshape (obj .shape )
110111 else :
111112 return value
112-
113+
113114 return _reconstruct (structure )
114115
115116
@@ -128,15 +129,10 @@ def _restore_legacy_format(
128129 target_model .weights [i ].assign (weight_tensor )
129130
130131 # Restore optimizer state if available
131- if (
132- "optimizer_state" in checkpoint_data
133- and save_optimizer_state
134- ):
132+ if "optimizer_state" in checkpoint_data and save_optimizer_state :
135133 optimizer_vars_np = checkpoint_data ["optimizer_state" ]
136134 # Only restore if the variable counts match
137- if len (optimizer_vars_np ) == len (
138- target_model .optimizer .variables
139- ):
135+ if len (optimizer_vars_np ) == len (target_model .optimizer .variables ):
140136 # Convert NumPy arrays back to backend tensors and assign to
141137 # optimizer
142138 for i , var_np in enumerate (optimizer_vars_np ):
@@ -406,14 +402,14 @@ def _save_checkpoint(self, step, logs=None):
406402 trainable_values = _flatten_state_tree_values (
407403 state_tree ["trainable_variables" ]
408404 )
409-
405+
410406 # Save optimizer and metrics state if requested
411407 optimizer_values = None
412408 if self .save_optimizer_state and "optimizer_variables" in state_tree :
413409 optimizer_values = _flatten_state_tree_values (
414410 state_tree ["optimizer_variables" ]
415411 )
416-
412+
417413 metrics_values = None
418414 if self .save_metrics_state and "metrics_variables" in state_tree :
419415 metrics_values = _flatten_state_tree_values (
@@ -423,7 +419,7 @@ def _save_checkpoint(self, step, logs=None):
423419 composite_state = {
424420 "model_weights" : trainable_values ,
425421 }
426-
422+
427423 if optimizer_values is not None :
428424 composite_state ["optimizer_state" ] = optimizer_values
429425 if metrics_values is not None :
@@ -611,8 +607,9 @@ def _restore_model_state(self, checkpoint_data, model=None):
611607 target_model = model if model is not None else self .model
612608
613609 # Check if this is the new flattened format
614- if ("model_weights" in checkpoint_data and
615- isinstance (checkpoint_data ["model_weights" ], list )):
610+ if "model_weights" in checkpoint_data and isinstance (
611+ checkpoint_data ["model_weights" ], list
612+ ):
616613 # New format: flattened values
617614 return self ._restore_from_flattened_values (
618615 checkpoint_data , target_model
@@ -625,8 +622,10 @@ def _restore_model_state(self, checkpoint_data, model=None):
625622 else :
626623 # Fallback to legacy format
627624 _restore_legacy_format (
628- checkpoint_data , target_model , self .save_optimizer_state ,
629- self .save_metrics_state
625+ checkpoint_data ,
626+ target_model ,
627+ self .save_optimizer_state ,
628+ self .save_metrics_state ,
630629 )
631630 return True
632631
@@ -649,9 +648,9 @@ def _restore_from_flattened_values(self, checkpoint_data, target_model):
649648 # Restore trainable variables
650649 if "model_weights" in checkpoint_data :
651650 saved_trainable_values = checkpoint_data ["model_weights" ]
652- target_trainable_structure = (
653- target_state_tree [ "trainable_variables" ]
654- )
651+ target_trainable_structure = target_state_tree [
652+ "trainable_variables"
653+ ]
655654 reconstructed_state ["trainable_variables" ] = (
656655 _reconstruct_state_tree_with_values (
657656 target_trainable_structure , saved_trainable_values
@@ -665,9 +664,9 @@ def _restore_from_flattened_values(self, checkpoint_data, target_model):
665664 and "optimizer_variables" in target_state_tree
666665 ):
667666 saved_optimizer_values = checkpoint_data ["optimizer_state" ]
668- target_optimizer_structure = (
669- target_state_tree [ "optimizer_variables" ]
670- )
667+ target_optimizer_structure = target_state_tree [
668+ "optimizer_variables"
669+ ]
671670 reconstructed_state ["optimizer_variables" ] = (
672671 _reconstruct_state_tree_with_values (
673672 target_optimizer_structure , saved_optimizer_values
@@ -702,5 +701,3 @@ def _restore_from_state_tree(self, state_tree, target_model):
702701 if self .verbose > 0 :
703702 print_msg ("OrbaxCheckpoint: Successfully restored model state" )
704703 return True
705-
706-
0 commit comments