Skip to content
Open
3 changes: 3 additions & 0 deletions keras/api/_tf_keras/keras/callbacks/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,9 @@
from keras.src.callbacks.model_checkpoint import (
ModelCheckpoint as ModelCheckpoint,
)
from keras.src.callbacks.orbax_checkpoint import (
OrbaxCheckpoint as OrbaxCheckpoint,
)
from keras.src.callbacks.progbar_logger import ProgbarLogger as ProgbarLogger
from keras.src.callbacks.reduce_lr_on_plateau import (
ReduceLROnPlateau as ReduceLROnPlateau,
Expand Down
3 changes: 3 additions & 0 deletions keras/api/callbacks/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,9 @@
from keras.src.callbacks.model_checkpoint import (
ModelCheckpoint as ModelCheckpoint,
)
from keras.src.callbacks.orbax_checkpoint import (
OrbaxCheckpoint as OrbaxCheckpoint,
)
from keras.src.callbacks.progbar_logger import ProgbarLogger as ProgbarLogger
from keras.src.callbacks.reduce_lr_on_plateau import (
ReduceLROnPlateau as ReduceLROnPlateau,
Expand Down
35 changes: 35 additions & 0 deletions keras/src/backend/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,3 +75,38 @@ class name_scope(backend_name_scope):
@keras_export("keras.device")
def device(device_name):
return device_scope(device_name) # noqa: F405


def get_process_index():
"""Get the index of the current process in a distributed setup.

Returns:
int: The process index (0 for primary process, >0 for others).
Returns 0 if not in a distributed setup.
"""
backend_name = backend()
if backend_name == "jax":
try:
import jax

return jax.process_index()
except (ImportError, AttributeError):
return 0
elif backend_name == "tensorflow":
try:
import tensorflow as tf

return tf.distribute.get_replica_context().replica_id_in_sync_group
except (ImportError, AttributeError, RuntimeError):
return 0
elif backend_name == "torch":
try:
import torch.distributed as dist

if dist.is_available() and dist.is_initialized():
return dist.get_rank()
return 0
except (ImportError, AttributeError):
return 0
else:
return 0
6 changes: 6 additions & 0 deletions keras/src/callbacks/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,12 @@
from keras.src.callbacks.learning_rate_scheduler import LearningRateScheduler
from keras.src.callbacks.model_checkpoint import ModelCheckpoint
from keras.src.callbacks.monitor_callback import MonitorCallback

try:
from keras.src.callbacks.orbax_checkpoint import OrbaxCheckpoint
except ImportError:
OrbaxCheckpoint = None

from keras.src.callbacks.progbar_logger import ProgbarLogger
from keras.src.callbacks.reduce_lr_on_plateau import ReduceLROnPlateau
from keras.src.callbacks.remote_monitor import RemoteMonitor
Expand Down
Loading