Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
6328350
Added OrbaxCheckpoint for keras 3.0 for Data centric saving and resto…
amitsrivastava78 Oct 21, 2025
ca71da6
Fix unused variable in orbax checkpoint test
amitsrivastava78 Oct 22, 2025
4dfa903
fixed failing cases
amitsrivastava78 Oct 22, 2025
7742139
fixed review comments
amitsrivastava78 Oct 22, 2025
822396f
Improve OrbaxCheckpoint implementation
amitsrivastava78 Oct 24, 2025
61bd5e6
Fix code formatting and remove unused variable
amitsrivastava78 Oct 24, 2025
19d2495
Add OrbaxCheckpoint callback with conditional exports and improved te…
amitsrivastava78 Oct 24, 2025
b56dc7b
Improve OrbaxCheckpoint: preserve nested structures, enhance tests
amitsrivastava78 Oct 28, 2025
7722e30
Fixed review comments
amitsrivastava78 Oct 31, 2025
eb7855d
Migration to Orbax V1
amitsrivastava78 Nov 5, 2025
aaf6e20
Fix sklearn wrapper CI tests by marking pipeline consistency checks a…
amitsrivastava78 Nov 10, 2025
cd881dd
made distributed structure proper
amitsrivastava78 Nov 10, 2025
9417027
Fixed sav decision between keras and orbax
amitsrivastava78 Nov 11, 2025
b7a0dff
Optimize Orbax checkpoint for JAX backend
amitsrivastava78 Nov 11, 2025
33f4e66
Optimize Orbax checkpoint for JAX backend with compatibility check
amitsrivastava78 Nov 11, 2025
d7884ef
added checkpointer.wait()
amitsrivastava78 Nov 12, 2025
13aec2e
Improve OrbaxCheckpoint callback with optimizations and cleanup
amitsrivastava78 Nov 13, 2025
a2938ea
Simplify OrbaxCheckpoint API to match ModelCheckpoint parity
amitsrivastava78 Nov 13, 2025
4d659f4
Removed the experimental import
amitsrivastava78 Nov 13, 2025
ce30b36
Add comprehensive OrbaxCheckpoint tests with loading verification
amitsrivastava78 Nov 14, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions keras/api/_tf_keras/keras/callbacks/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,9 @@
from keras.src.callbacks.model_checkpoint import (
ModelCheckpoint as ModelCheckpoint,
)
from keras.src.callbacks.orbax_checkpoint import (
OrbaxCheckpoint as OrbaxCheckpoint,
)
from keras.src.callbacks.progbar_logger import ProgbarLogger as ProgbarLogger
from keras.src.callbacks.reduce_lr_on_plateau import (
ReduceLROnPlateau as ReduceLROnPlateau,
Expand Down
3 changes: 3 additions & 0 deletions keras/api/callbacks/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,9 @@
from keras.src.callbacks.model_checkpoint import (
ModelCheckpoint as ModelCheckpoint,
)
from keras.src.callbacks.orbax_checkpoint import (
OrbaxCheckpoint as OrbaxCheckpoint,
)
from keras.src.callbacks.progbar_logger import ProgbarLogger as ProgbarLogger
from keras.src.callbacks.reduce_lr_on_plateau import (
ReduceLROnPlateau as ReduceLROnPlateau,
Expand Down
35 changes: 35 additions & 0 deletions keras/src/backend/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,3 +75,38 @@ class name_scope(backend_name_scope):
@keras_export("keras.device")
def device(device_name):
    """Return the backend's device scope for `device_name`.

    Thin public wrapper: the argument is forwarded unchanged to the
    backend-provided `device_scope`, and its result is handed back as-is.

    Args:
        device_name: Device identifier, passed through to `device_scope`.

    Returns:
        Whatever `device_scope(device_name)` returns.
    """
    scope = device_scope(device_name)  # noqa: F405
    return scope


def get_process_index():
    """Report which process of a distributed run is executing this code.

    The lookup depends on the active Keras backend:

    - `"jax"`: `jax.process_index()`.
    - `"tensorflow"`: `replica_id_in_sync_group` of the current replica
      context.
    - `"torch"`: `torch.distributed.get_rank()`, but only when
      `torch.distributed` is both available and initialized.

    Every other backend, and every failure caught below (framework not
    importable, attribute missing, ...), yields `0`.

    Returns:
        int: The process index (0 for the primary process, >0 for others).
            0 is also returned when not in a distributed setup.
    """
    active_backend = backend()

    if active_backend == "jax":
        try:
            import jax

            return jax.process_index()
        except (ImportError, AttributeError):
            return 0

    if active_backend == "tensorflow":
        try:
            import tensorflow as tf

            # NOTE(review): this is a *replica* id, which is presumably not
            # always the same thing as a process index (and may not be a
            # plain Python int) -- confirm against the callers' needs.
            replica_context = tf.distribute.get_replica_context()
            return replica_context.replica_id_in_sync_group
        except (ImportError, AttributeError, RuntimeError):
            return 0

    if active_backend == "torch":
        try:
            import torch.distributed as dist

            # Without an initialized process group there is no rank to ask
            # for, so report the primary index.
            if not (dist.is_available() and dist.is_initialized()):
                return 0
            return dist.get_rank()
        except (ImportError, AttributeError):
            return 0

    # Unknown / non-distributed backend.
    return 0
6 changes: 6 additions & 0 deletions keras/src/callbacks/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,12 @@
from keras.src.callbacks.learning_rate_scheduler import LearningRateScheduler
from keras.src.callbacks.model_checkpoint import ModelCheckpoint
from keras.src.callbacks.monitor_callback import MonitorCallback

try:
from keras.src.callbacks.orbax_checkpoint import OrbaxCheckpoint
except ImportError:
OrbaxCheckpoint = None

from keras.src.callbacks.progbar_logger import ProgbarLogger
from keras.src.callbacks.reduce_lr_on_plateau import ReduceLROnPlateau
from keras.src.callbacks.remote_monitor import RemoteMonitor
Expand Down
Loading
Loading