
Commit b56dc7b

Improve OrbaxCheckpoint: preserve nested structures, enhance tests
- Preserve nested state tree structures instead of flattening for better layer name preservation
- Add backward compatibility for old flattened format checkpoints
- Simplify test class by using self.get_temp_dir() instead of setUp/tearDown
- Remove silent pytest.importorskip, add explicit skip conditions for backend-specific tests
- Move process_id function from backend to distribution module
- Update imports to use centralized LazyModule for orbax.checkpoint
- Test across all backends (JAX, TensorFlow, PyTorch) - all passing
1 parent 19d2495 commit b56dc7b
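
For context, a minimal usage sketch of the callback this commit modifies. The data, model, and directory are illustrative; the constructor arguments mirror those visible in the diff below, and the import path is the source module this commit edits:

import numpy as np
import keras
from keras.src.callbacks.orbax_checkpoint import OrbaxCheckpoint

x = np.random.rand(32, 4).astype("float32")
y = np.random.rand(32, 1).astype("float32")

model = keras.Sequential([keras.layers.Dense(1)])
model.compile(optimizer="adam", loss="mse")

# save_optimizer_state / save_metrics_state / verbose all appear in
# the callback's diff below; the directory is a hypothetical path.
checkpoint = OrbaxCheckpoint(
    directory="/tmp/orbax_ckpt",
    save_optimizer_state=True,
    save_metrics_state=True,
    verbose=1,
)
model.fit(x, y, epochs=2, callbacks=[checkpoint])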

File tree

7 files changed, +301 -230 lines changed


keras/api/_tf_keras/keras/distribution/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -17,6 +17,7 @@
 from keras.src.distribution.distribution_lib import distribution as distribution
 from keras.src.distribution.distribution_lib import initialize as initialize
 from keras.src.distribution.distribution_lib import list_devices as list_devices
+from keras.src.distribution.distribution_lib import process_id as process_id
 from keras.src.distribution.distribution_lib import (
     set_distribution as set_distribution,
 )

keras/api/distribution/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -17,6 +17,7 @@
 from keras.src.distribution.distribution_lib import distribution as distribution
 from keras.src.distribution.distribution_lib import initialize as initialize
 from keras.src.distribution.distribution_lib import list_devices as list_devices
+from keras.src.distribution.distribution_lib import process_id as process_id
 from keras.src.distribution.distribution_lib import (
     set_distribution as set_distribution,
 )

keras/src/backend/__init__.py

Lines changed: 0 additions & 35 deletions
@@ -75,38 +75,3 @@ class name_scope(backend_name_scope):
 @keras_export("keras.device")
 def device(device_name):
     return device_scope(device_name)  # noqa: F405
-
-
-def get_process_index():
-    """Get the index of the current process in a distributed setup.
-
-    Returns:
-        int: The process index (0 for primary process, >0 for others).
-            Returns 0 if not in a distributed setup.
-    """
-    backend_name = backend()
-    if backend_name == "jax":
-        try:
-            import jax
-
-            return jax.process_index()
-        except (ImportError, AttributeError):
-            return 0
-    elif backend_name == "tensorflow":
-        try:
-            import tensorflow as tf
-
-            return tf.distribute.get_replica_context().replica_id_in_sync_group
-        except (ImportError, AttributeError, RuntimeError):
-            return 0
-    elif backend_name == "torch":
-        try:
-            import torch.distributed as dist
-
-            if dist.is_available() and dist.is_initialized():
-                return dist.get_rank()
-            return 0
-        except (ImportError, AttributeError):
-            return 0
-    else:
-        return 0
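
With the backend helper removed, the single source of truth is process_id() in the distribution module, which the two API diffs above also export as keras.distribution.process_id. A minimal sketch of the rank-0 guard pattern the callback relies on below:

from keras.src.distribution.distribution_lib import process_id

# process_id() returns 0 on the primary process (and in
# non-distributed runs), so it can gate primary-only side effects
# such as creating checkpoint directories.
if process_id() == 0:
    print("primary process: performing filesystem setup")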

keras/src/callbacks/orbax_checkpoint.py

Lines changed: 159 additions & 92 deletions
@@ -5,25 +5,14 @@
 
 from keras.src import backend
 from keras.src import ops
+from keras.src import tree
 from keras.src.api_export import keras_export
 from keras.src.callbacks.monitor_callback import (
     MonitorCallback,  # For metric monitoring logic
 )
+from keras.src.distribution.distribution_lib import process_id
 from keras.src.utils.io_utils import print_msg
-from keras.src.utils.module_utils import LazyModule
-
-ocp = LazyModule(
-    "orbax.checkpoint",
-    pip_name="orbax-checkpoint",
-    import_error_msg=(
-        "OrbaxCheckpoint requires the 'orbax-checkpoint' package. "
-        "Install it with: pip install orbax-checkpoint"
-    ),
-)
-
-# Note: Advanced Orbax functionality is available through the ocp LazyModule
-# Users can access it via: from keras.src.utils.module_utils import LazyModule
-# ocp = LazyModule("orbax.checkpoint"); ocp.CheckpointManager
+from keras.src.utils.module_utils import ocp
 
 
 def _get_state_tree(model):
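
The hand-rolled dict recursion that follows is replaced by keras.src.tree utilities, and ocp is now the LazyModule defined centrally in keras.src.utils.module_utils. A small sketch of the two tree calls the new code leans on (the state dict is made up):

import numpy as np
from keras.src import tree

state = {"dense": {"kernel": np.ones((2, 2)), "bias": np.zeros((2,))}}

# map_structure applies a function to every leaf while preserving
# the nested dict structure (and hence the layer names).
scaled = tree.map_structure(lambda v: v * 2.0, state)

# flatten returns the leaves in a deterministic, key-sorted order;
# the legacy flattened checkpoint format relied on exactly this.
leaves = tree.flatten(state)
assert len(leaves) == 2  # bias first, then kernel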
@@ -38,68 +27,49 @@ def convert_scalars(obj):
         elif isinstance(obj, np.generic):
             # Convert numpy scalar types (like np.float32) to Python types
             return obj.item()
-        elif isinstance(obj, dict):
-            return {k: convert_scalars(v) for k, v in obj.items()}
         else:
             return obj
 
-    return convert_scalars(state_tree)
+    return tree.map_structure(convert_scalars, state_tree)
 
 
 def _flatten_state_tree_values(state_tree):
     """Flatten nested state tree into a list of values in consistent order."""
-    values = []
-
-    def _flatten(obj):
-        if isinstance(obj, dict):
-            for key in sorted(obj.keys()):  # Sort for consistent ordering
-                _flatten(obj[key])
-        else:
-            # Save any non-dict value (numpy arrays, lists, scalars, etc.)
-            values.append(obj)
-
-    _flatten(state_tree)
-    return values
+    return tree.flatten(state_tree)
 
 
 def _reconstruct_state_tree_with_values(structure, values):
     """Reconstruct state tree structure with provided values."""
     value_iter = iter(values)
 
-    def _reconstruct(obj):
-        if isinstance(obj, dict):
-            new_dict = {}
-            for key in sorted(obj.keys()):
-                new_dict[key] = _reconstruct(obj[key])
-            return new_dict
-        else:
-            value = next(value_iter)
-            # Handle different cases for value conversion
-            if isinstance(obj, np.generic):
-                # obj is a numpy scalar (0-dimensional)
-                if isinstance(value, (int, float)):
-                    # Convert Python scalar to numpy scalar
-                    return np.array(value, dtype=obj.dtype)
-                elif isinstance(value, np.ndarray):
-                    # value is a numpy array, convert to scalar if needed
-                    if value.ndim == 0:
-                        return np.array(value.item(), dtype=obj.dtype)
-                    elif value.ndim == 1 and value.size == 1:
-                        return np.array(value.item(), dtype=obj.dtype)
-                    else:
-                        return value.astype(obj.dtype).reshape(obj.shape)
+    def _reconstruct_value(obj):
+        value = next(value_iter)
+        # Handle different cases for value conversion
+        if isinstance(obj, np.generic):
+            # obj is a numpy scalar (0-dimensional)
+            if isinstance(value, (int, float)):
+                # Convert Python scalar to numpy scalar
+                return np.array(value, dtype=obj.dtype)
+            elif isinstance(value, np.ndarray):
+                # value is a numpy array, convert to scalar if needed
+                if value.ndim == 0:
+                    return np.array(value.item(), dtype=obj.dtype)
+                elif value.ndim == 1 and value.size == 1:
+                    return np.array(value.item(), dtype=obj.dtype)
                 else:
-                    return np.array(value, dtype=obj.dtype)
-            elif isinstance(obj, np.ndarray):
-                # obj is a numpy array
-                if isinstance(value, np.ndarray):
                     return value.astype(obj.dtype).reshape(obj.shape)
-                else:
-                    return np.array(value, dtype=obj.dtype).reshape(obj.shape)
             else:
-                return value
+                return np.array(value, dtype=obj.dtype)
+        elif isinstance(obj, np.ndarray):
+            # obj is a numpy array
+            if isinstance(value, np.ndarray):
+                return value.astype(obj.dtype).reshape(obj.shape)
+            else:
+                return np.array(value, dtype=obj.dtype).reshape(obj.shape)
+        else:
+            return value
 
-    return _reconstruct(structure)
+    return tree.map_structure(_reconstruct_value, structure)
 
 
 def _restore_legacy_format(
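
Together, the two helpers above form a flatten/rebuild round trip over a template structure. An illustrative sketch, using a hypothetical template:

import numpy as np

template = {
    "bias": np.zeros((2,), dtype="float32"),
    "count": np.float32(0.0),
}

values = _flatten_state_tree_values(template)  # key-sorted leaves
rebuilt = _reconstruct_state_tree_with_values(template, values)

# Each restored leaf is coerced back to the template's dtype and
# shape, so rebuilt["bias"] is a float32 array of shape (2,).
assert rebuilt["bias"].dtype == np.float32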
@@ -327,7 +297,7 @@ def __init__(
             save_decision_policy=save_decision_policy,
         )
         # Ensure directory exists (only needed on one process in multi-host)
-        if backend.get_process_index() == 0:
+        if process_id() == 0:
             os.makedirs(directory, exist_ok=True)
 
         # Create the CheckpointManager
@@ -380,38 +350,27 @@ def _save_checkpoint(self, step, logs=None):
         state_tree = _get_state_tree(self.model)
 
         if state_tree is None:
-            if self.verbose > 0:
-                print_msg(
-                    "OrbaxCheckpoint: Skipping save due to state tree error"
-                )
-            return
-
-        # Flatten the trainable variables values for cross-model compatibility
-        trainable_values = _flatten_state_tree_values(
-            state_tree["trainable_variables"]
-        )
-
-        # Save optimizer and metrics state if requested
-        optimizer_values = None
-        if self.save_optimizer_state and "optimizer_variables" in state_tree:
-            optimizer_values = _flatten_state_tree_values(
-                state_tree["optimizer_variables"]
-            )
-
-        metrics_values = None
-        if self.save_metrics_state and "metrics_variables" in state_tree:
-            metrics_values = _flatten_state_tree_values(
-                state_tree["metrics_variables"]
+            raise RuntimeError(
+                "OrbaxCheckpoint: Failed to get model state tree. "
+                "The model may not be properly built or may have no "
+                "savable state."
             )
 
+        # Save the nested state structures directly (preserving layer
+        # names and structure)
         composite_state = {
-            "model_weights": trainable_values,
+            "trainable_variables": state_tree["trainable_variables"],
         }
 
-        if optimizer_values is not None:
-            composite_state["optimizer_state"] = optimizer_values
-        if metrics_values is not None:
-            composite_state["metrics_variables"] = metrics_values
+        if self.save_optimizer_state and "optimizer_variables" in state_tree:
+            composite_state["optimizer_variables"] = state_tree[
+                "optimizer_variables"
+            ]
+
+        if self.save_metrics_state and "metrics_variables" in state_tree:
+            composite_state["metrics_variables"] = state_tree[
+                "metrics_variables"
+            ]
 
         # Add metadata if specified
         if self.save_metadata is not None:
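
The composite state now mirrors Model.get_state_tree(), keyed by layer and variable names instead of positional lists. A sketch of the structure being saved (default Keras layer names; actual keys depend on the model):

import keras

model = keras.Sequential([keras.layers.Dense(2), keras.layers.Dense(1)])
model.build((None, 4))
model.compile(optimizer="adam", loss="mse")

state = model.get_state_tree(value_format="numpy_array")
# Nested dicts such as, roughly:
#   state["trainable_variables"]["dense"]["kernel"]  -> (4, 2) array
#   state["trainable_variables"]["dense_1"]["bias"]  -> (1,) array
print(sorted(state.keys()))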
@@ -435,7 +394,7 @@ def _save_checkpoint(self, step, logs=None):
 
         # --- Save Logic ---
         # Only save on the primary process (rank 0) in distributed setups
-        is_primary_host = backend.get_process_index() == 0
+        is_primary_host = process_id() == 0
 
         if is_primary_host:
             if self.verbose > 0:
@@ -540,7 +499,7 @@ def load_checkpoint(self, step, model=None):
                 data iterator state dict if available, None otherwise.
         """
         # In distributed training, only load on primary process
-        if backend.get_process_index() != 0:
+        if process_id() != 0:
             return True  # Return True to indicate no error, but no loading
 
         if self.verbose > 0:
@@ -594,11 +553,18 @@ def _restore_model_state(self, checkpoint_data, model=None):
         """
         target_model = model if model is not None else self.model
 
-        # Check if this is the new flattened format
-        if "model_weights" in checkpoint_data and isinstance(
+        # Check if this is the new nested structure format
+        if "trainable_variables" in checkpoint_data and isinstance(
+            checkpoint_data["trainable_variables"], dict
+        ):
+            # New format: nested structures
+            return self._restore_from_nested_structures(
+                checkpoint_data, target_model
+            )
+        elif "model_weights" in checkpoint_data and isinstance(
             checkpoint_data["model_weights"], list
         ):
-            # New format: flattened values
+            # Old format: flattened values (for backward compatibility)
             return self._restore_from_flattened_values(
                 checkpoint_data, target_model
             )
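
For backward compatibility, the restore path sniffs the stored layout: a dict under trainable_variables marks the new nested format, while a list under model_weights marks the legacy flattened one. A standalone paraphrase of that dispatch (not the callback's own code):

def detect_format(checkpoint_data):
    # Nested dicts take priority; flattened lists are the fallback,
    # and anything else falls through to the legacy restore path.
    if isinstance(checkpoint_data.get("trainable_variables"), dict):
        return "nested"
    if isinstance(checkpoint_data.get("model_weights"), list):
        return "flattened"
    return "legacy"

assert detect_format({"trainable_variables": {"dense": {}}}) == "nested"
assert detect_format({"model_weights": [0.1, 0.2]}) == "flattened"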
@@ -617,8 +583,109 @@ def _restore_model_state(self, checkpoint_data, model=None):
             )
         return True
 
+    def _restore_from_nested_structures(self, checkpoint_data, target_model):
+        """Restore from the new nested structures format."""
+        # Ensure the target model is built so it has variables
+        if len(target_model.trainable_variables) == 0:
+            try:
+                # Try to build the model by doing a dummy forward pass
+                if (
+                    hasattr(target_model, "input_shape")
+                    and target_model.input_shape is not None
+                ):
+                    dummy_input_shape = target_model.input_shape
+                    if dummy_input_shape[0] is None:  # Batch dimension is None
+                        dummy_input = np.zeros((1,) + dummy_input_shape[1:])
+                    else:
+                        dummy_input = np.zeros(dummy_input_shape)
+                    target_model(dummy_input)
+            except Exception:
+                # If dummy forward pass fails, try build
+                try:
+                    if (
+                        hasattr(target_model, "input_shape")
+                        and target_model.input_shape is not None
+                    ):
+                        build_shape = target_model.input_shape
+                        if (
+                            isinstance(build_shape, (list, tuple))
+                            and len(build_shape) > 1
+                            and build_shape[0] is None
+                        ):
+                            build_shape = build_shape[1:]
+                        target_model.build(build_shape)
+                except Exception:
+                    # If building fails, continue anyway
+                    pass
+
+        # Prepare the state tree to restore
+        reconstructed_state = {}
+
+        # Restore trainable variables
+        if "trainable_variables" in checkpoint_data:
+            reconstructed_state["trainable_variables"] = checkpoint_data[
+                "trainable_variables"
+            ]
+
+        # Restore optimizer variables if available and model has optimizer
+        if (
+            "optimizer_variables" in checkpoint_data
+            and self.save_optimizer_state
+            and hasattr(target_model, "optimizer")
+            and target_model.optimizer is not None
+        ):
+            reconstructed_state["optimizer_variables"] = checkpoint_data[
+                "optimizer_variables"
+            ]
+
+        # Restore metrics variables if available
+        if "metrics_variables" in checkpoint_data and self.save_metrics_state:
+            reconstructed_state["metrics_variables"] = checkpoint_data[
+                "metrics_variables"
+            ]
+
+        # Use set_state_tree to restore the state
+        target_model.set_state_tree(reconstructed_state)
+
+        if self.verbose > 0:
+            print_msg("OrbaxCheckpoint: Successfully restored model state")
+        return True
+
     def _restore_from_flattened_values(self, checkpoint_data, target_model):
         """Restore from the new flattened values format."""
+        # Ensure the target model is built so it has variables
+        if len(target_model.trainable_variables) == 0:
+            try:
+                # Try to build the model by doing a dummy forward pass
+                if (
+                    hasattr(target_model, "input_shape")
+                    and target_model.input_shape is not None
+                ):
+                    dummy_input_shape = target_model.input_shape
+                    if dummy_input_shape[0] is None:  # Batch dimension is None
+                        dummy_input = np.zeros((1,) + dummy_input_shape[1:])
+                    else:
+                        dummy_input = np.zeros(dummy_input_shape)
+                    target_model(dummy_input)
+            except Exception:
+                # If dummy forward pass fails, try build
+                try:
+                    if (
+                        hasattr(target_model, "input_shape")
+                        and target_model.input_shape is not None
+                    ):
+                        build_shape = target_model.input_shape
+                        if (
+                            isinstance(build_shape, (list, tuple))
+                            and len(build_shape) > 1
+                            and build_shape[0] is None
+                        ):
+                            build_shape = build_shape[1:]
+                        target_model.build(build_shape)
+                except Exception:
+                    # If building fails, continue anyway
+                    pass
+
         # Get the target model's state tree structure (without convert_scalars)
         target_state_tree = target_model.get_state_tree(
             value_format="numpy_array"
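
End to end, the nested format preserves layer naming across a save/restore cycle. A hedged round-trip sketch (the directory, step number, and save cadence are illustrative assumptions):

import numpy as np
import keras
from keras.src.callbacks.orbax_checkpoint import OrbaxCheckpoint

x = np.random.rand(8, 4).astype("float32")
y = np.random.rand(8, 1).astype("float32")

model = keras.Sequential([keras.layers.Dense(1)])
model.compile(optimizer="adam", loss="mse")

checkpoint = OrbaxCheckpoint(directory="/tmp/orbax_demo")
model.fit(x, y, epochs=1, callbacks=[checkpoint])

# Restore into a fresh model: load_checkpoint dispatches on the
# stored layout and calls set_state_tree for new-format checkpoints.
fresh = keras.Sequential([keras.layers.Dense(1)])
fresh.build((None, 4))
fresh.compile(optimizer="adam", loss="mse")
checkpoint.load_checkpoint(step=0, model=fresh)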
