Skip to content

Commit 7caff51

Browse files
Fix OrbaxCheckpoint sharding and multi-host issues
- Fix sharding parameter passing in save/restore operations by passing it as kwargs instead of setting attributes on StandardSave/StandardRestore objects.
- Add robust error handling for distribution initialization with multiple error message patterns.
- Add proper test skipping for JAX-only features when the distribution module is unavailable.
- Add sharding parameter validation in the constructor to prevent invalid types.
- Update test expectations to match the corrected sharding validation behavior.

These changes ensure proper sharding support for JAX multi-host checkpointing while maintaining backward compatibility.
1 parent ece595d commit 7caff51

File tree

2 files changed

+31
-17
lines changed

2 files changed

+31
-17
lines changed

keras/src/callbacks/orbax_checkpoint.py

Lines changed: 26 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -345,6 +345,16 @@ def __init__(
345345
"sharding and multi_host parameters are only supported "
346346
"with JAX backend. Current backend: " + backend.backend()
347347
)
348+
349+
# Validate sharding object type
350+
if sharding is not None and backend.backend() == "jax":
351+
# Basic validation: sharding should not be a string or other
352+
# primitive type
353+
if isinstance(sharding, (str, int, float, bool)):
354+
raise TypeError(
355+
f"sharding parameter must be a valid JAX sharding object, "
356+
f"got {type(sharding).__name__}: {sharding}"
357+
)
348358
self._batches_seen_since_last_saving = 0
349359
self._last_batch_seen = 0
350360
self._current_epoch = 0 # Keep track of epoch
@@ -395,9 +405,14 @@ def __init__(
395405
except RuntimeError as e:
396406
# If distributed cannot be initialized (e.g., JAX already
397407
# initialized), continue anyway - the multi_host flag is mainly
398-
# a hint to Orbax
399-
if "must be called before" in str(e):
400-
pass # This is expected in test environments
408+
# a hint to Orbax.
409+
# We check for messages related to initialization state.
410+
error_str = str(e).lower()
411+
if (
412+
"already been initialized" in error_str
413+
or "must be called before" in error_str
414+
):
415+
pass # This is expected in some environments.
401416
else:
402417
raise
403418
# Orbax will automatically handle multi-host coordination:
@@ -529,15 +544,10 @@ def _save_checkpoint(self, step, logs=None):
529544
)
530545

531546
# Apply sharding if specified (JAX only)
547+
save_kwargs = {}
532548
if self.sharding is not None and backend.backend() == "jax":
533-
# For JAX sharding, we need to ensure the data is properly
534-
# sharded
535-
# This is typically handled automatically by Orbax when JAX
536-
# arrays with sharding metadata are saved
537-
if hasattr(save_args, "sharding"):
538-
save_args.sharding = self.sharding
539-
540-
self.manager.save(step, args=save_args)
549+
save_kwargs["sharding"] = self.sharding
550+
self.manager.save(step, args=save_args, **save_kwargs)
541551

542552
def on_train_batch_end(self, batch, logs=None):
543553
if self._should_save_on_batch(batch):
@@ -650,14 +660,14 @@ def load_checkpoint(self, step, model=None):
650660
restore_args = ocp.args.StandardRestore()
651661

652662
# Apply sharding if specified (JAX only)
663+
restore_kwargs = {}
653664
if self.sharding is not None and backend.backend() == "jax":
654-
# For JAX sharding, we need to ensure the data is properly restored
655-
# with the same sharding specification used during save
656-
if hasattr(restore_args, "sharding"):
657-
restore_args.sharding = self.sharding
665+
restore_kwargs["sharding"] = self.sharding
658666

659667
# Load the checkpoint
660-
checkpoint_data = self.manager.restore(step, args=restore_args)
668+
checkpoint_data = self.manager.restore(
669+
step, args=restore_args, **restore_kwargs
670+
)
661671

662672
# Restore the model state
663673
target_model = model if model is not None else self.model

keras/src/callbacks/orbax_checkpoint_test.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2044,10 +2044,14 @@ def test_restore_unsharded_checkpoint_to_sharded_model(self):
20442044
"unsharded checkpoint",
20452045
)
20462046

2047+
@pytest.mark.skipif(
2048+
backend.backend() != "jax",
2049+
reason="Sharding validation tests require JAX backend",
2050+
)
20472051
def test_invalid_sharding_argument_raises_error(self):
20482052
"""Test that invalid sharding arguments raise TypeError."""
20492053
# Test with string (invalid sharding object)
2050-
with self.assertRaises((TypeError, ValueError)):
2054+
with self.assertRaises(TypeError):
20512055
OrbaxCheckpoint(
20522056
directory=os.path.join(self.temp_dir, "test_invalid_sharding"),
20532057
sharding="invalid_sharding_string",

0 commit comments

Comments
 (0)