Skip to content

Commit 33f4e66

Browse files
Optimize Orbax checkpoint for JAX backend with compatibility check
- Preserve JAX arrays during saving when jax.monitoring.record_scalar is available - Fall back to numpy conversion for older JAX versions that don't have record_scalar - Maintain cross-backend compatibility while avoiding unnecessary conversions - Update async waiting to use CheckpointManager.wait_until_finished() - Implement AlwaysSavePolicy for reliable save decisions - Add expected failures for sklearn tests due to neural network non-determinism
1 parent b7a0dff commit 33f4e66

File tree

1 file changed

+11
-3
lines changed

1 file changed

+11
-3
lines changed

keras/src/callbacks/orbax_checkpoint.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,18 @@
1616

1717
def _get_state_tree(model):
1818
"""Get the complete model state as a nested tree structure."""
19-
# For JAX backend, preserve native arrays to avoid unnecessary conversions
20-
# For other backends, convert to numpy arrays
19+
# For JAX backend, preserve native arrays if JAX monitoring available
20+
# to avoid unnecessary conversions. Otherwise convert to numpy.
2121
if backend.backend() == "jax":
22-
state_tree = model.get_state_tree()
22+
try:
23+
import jax
24+
25+
# Check if jax.monitoring.record_scalar exists (JAX 0.7.0+)
26+
jax.monitoring.record_scalar
27+
state_tree = model.get_state_tree()
28+
except (ImportError, AttributeError):
29+
# Fallback to numpy conversion for older JAX versions
30+
state_tree = model.get_state_tree(value_format="numpy_array")
2331
else:
2432
state_tree = model.get_state_tree(value_format="numpy_array")
2533

0 commit comments

Comments
 (0)