You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
Optimize Orbax checkpoint for JAX backend with compatibility check
- Preserve JAX arrays during saving when jax.monitoring.record_scalar is available
- Fall back to numpy conversion for older JAX versions that don't have record_scalar
- Maintain cross-backend compatibility while avoiding unnecessary conversions
- Update async waiting to use CheckpointManager.wait_until_finished()
- Implement AlwaysSavePolicy for reliable save decisions
- Add expected failures for sklearn tests due to neural network non-determinism
0 commit comments