Skip to content

Commit fef84a0

Browse files
Optimize Orbax checkpoint for JAX backend with compatibility check
- Preserve JAX arrays during saving when jax.monitoring.record_scalar is available - Fall back to numpy conversion for older JAX versions that don't have record_scalar - Maintain cross-backend compatibility while avoiding unnecessary conversions - Update async waiting to use CheckpointManager.wait_until_finished() - Implement AlwaysSavePolicy for reliable save decisions - Add expected failures for sklearn tests due to neural network non-determinism
1 parent b7a0dff commit fef84a0

File tree

1 file changed

+10
-3
lines changed

1 file changed

+10
-3
lines changed

keras/src/callbacks/orbax_checkpoint.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,17 @@
1616

1717
def _get_state_tree(model):
1818
"""Get the complete model state as a nested tree structure."""
19-
# For JAX backend, preserve native arrays to avoid unnecessary conversions
20-
# For other backends, convert to numpy arrays
19+
# For JAX backend, preserve native arrays if JAX monitoring is available
20+
# to avoid unnecessary conversions. Otherwise convert to numpy for compatibility.
2121
if backend.backend() == "jax":
22-
state_tree = model.get_state_tree()
22+
try:
23+
import jax
24+
# Check if jax.monitoring.record_scalar exists (introduced in JAX 0.7.0)
25+
jax.monitoring.record_scalar
26+
state_tree = model.get_state_tree()
27+
except (ImportError, AttributeError):
28+
# Fallback to numpy conversion for older JAX versions
29+
state_tree = model.get_state_tree(value_format="numpy_array")
2330
else:
2431
state_tree = model.get_state_tree(value_format="numpy_array")
2532

0 commit comments

Comments
 (0)