Skip to content
Open
3 changes: 3 additions & 0 deletions keras/api/_tf_keras/keras/callbacks/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,9 @@
from keras.src.callbacks.model_checkpoint import (
ModelCheckpoint as ModelCheckpoint,
)
from keras.src.callbacks.orbax_checkpoint import (
OrbaxCheckpoint as OrbaxCheckpoint,
)
from keras.src.callbacks.progbar_logger import ProgbarLogger as ProgbarLogger
from keras.src.callbacks.reduce_lr_on_plateau import (
ReduceLROnPlateau as ReduceLROnPlateau,
Expand Down
3 changes: 3 additions & 0 deletions keras/api/callbacks/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,9 @@
from keras.src.callbacks.model_checkpoint import (
ModelCheckpoint as ModelCheckpoint,
)
from keras.src.callbacks.orbax_checkpoint import (
OrbaxCheckpoint as OrbaxCheckpoint,
)
from keras.src.callbacks.progbar_logger import ProgbarLogger as ProgbarLogger
from keras.src.callbacks.reduce_lr_on_plateau import (
ReduceLROnPlateau as ReduceLROnPlateau,
Expand Down
35 changes: 35 additions & 0 deletions keras/src/backend/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,3 +75,38 @@ class name_scope(backend_name_scope):
@keras_export("keras.device")
def device(device_name):
    """Return a device scope for `device_name`.

    Thin public wrapper over `device_scope`; ops created inside the
    returned scope are placed on the requested device.
    """
    return device_scope(device_name) # noqa: F405


def get_process_index():
    """Get the index of the current process in a distributed setup.

    Returns:
        int: The process index (0 for the primary process, >0 for others).
            Returns 0 if not in a distributed setup, or if the backend's
            distribution information is unavailable.
    """
    backend_name = backend()
    if backend_name == "jax":
        try:
            import jax

            return jax.process_index()
        except (ImportError, AttributeError):
            # JAX missing or too old to expose `process_index`.
            return 0
    elif backend_name == "tensorflow":
        try:
            import tensorflow as tf

            # NOTE(review): `replica_id_in_sync_group` is the replica index
            # within the sync group, which matches the process index only in
            # one-replica-per-process setups — confirm for multi-replica
            # workers. It can also be a scalar Tensor, so cast explicitly to
            # honor the documented `int` return type.
            return int(
                tf.distribute.get_replica_context().replica_id_in_sync_group
            )
        except (
            ImportError,
            AttributeError,
            RuntimeError,
            TypeError,
            ValueError,
        ):
            # TF missing, no replica context, or a non-castable symbolic id.
            return 0
    elif backend_name == "torch":
        try:
            import torch.distributed as dist

            # `get_rank` is only meaningful once the process group exists.
            if dist.is_available() and dist.is_initialized():
                return dist.get_rank()
            return 0
        except (ImportError, AttributeError):
            return 0
    else:
        # Backends without distribution support (e.g. numpy, openvino).
        return 0
2 changes: 1 addition & 1 deletion keras/src/backend/jax/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from keras.src.backend.config import is_nnx_enabled
from keras.src.backend.jax import core
from keras.src.backend.jax import distribution_lib
from keras.src.backend.jax import image
from keras.src.backend.jax import linalg
from keras.src.backend.jax import math
Expand All @@ -25,6 +24,7 @@
from keras.src.backend.jax.core import shape
from keras.src.backend.jax.core import stop_gradient
from keras.src.backend.jax.core import vectorized_map
from keras.src.backend.jax.distribution_lib import process_id
from keras.src.backend.jax.rnn import cudnn_ok
from keras.src.backend.jax.rnn import gru
from keras.src.backend.jax.rnn import lstm
Expand Down
1 change: 1 addition & 0 deletions keras/src/backend/numpy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from keras.src.backend.numpy.core import random_seed_dtype
from keras.src.backend.numpy.core import shape
from keras.src.backend.numpy.core import vectorized_map
from keras.src.backend.numpy.distribution_lib import process_id
from keras.src.backend.numpy.rnn import cudnn_ok
from keras.src.backend.numpy.rnn import gru
from keras.src.backend.numpy.rnn import lstm
Expand Down
6 changes: 6 additions & 0 deletions keras/src/backend/numpy/distribution_lib.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
"""Utilities for distribution strategy with NumPy backend."""


def process_id():
    """Return the current process ID for the distribution setting.

    The NumPy backend runs in a single process, so the ID is always 0.
    """
    return 0
2 changes: 2 additions & 0 deletions keras/src/backend/openvino/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from keras.src.backend.common.name_scope import name_scope
from keras.src.backend.openvino import core
from keras.src.backend.openvino import distribution_lib
from keras.src.backend.openvino import image
from keras.src.backend.openvino import linalg
from keras.src.backend.openvino import math
Expand All @@ -19,6 +20,7 @@
from keras.src.backend.openvino.core import random_seed_dtype
from keras.src.backend.openvino.core import shape
from keras.src.backend.openvino.core import vectorized_map
from keras.src.backend.openvino.distribution_lib import process_id
from keras.src.backend.openvino.rnn import cudnn_ok
from keras.src.backend.openvino.rnn import gru
from keras.src.backend.openvino.rnn import lstm
Expand Down
6 changes: 6 additions & 0 deletions keras/src/backend/openvino/distribution_lib.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
"""Utilities for distribution strategy with OpenVINO backend."""


def process_id():
    """Return the current process ID for the distribution setting.

    OpenVINO inference is single-process here, so this is constantly 0.
    """
    return 0
2 changes: 1 addition & 1 deletion keras/src/backend/tensorflow/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
from keras.src.backend.tensorflow import core
from keras.src.backend.tensorflow import distribution_lib
from keras.src.backend.tensorflow import image
from keras.src.backend.tensorflow import linalg
from keras.src.backend.tensorflow import math
Expand All @@ -24,6 +23,7 @@
from keras.src.backend.tensorflow.core import shape
from keras.src.backend.tensorflow.core import stop_gradient
from keras.src.backend.tensorflow.core import vectorized_map
from keras.src.backend.tensorflow.distribution_lib import process_id
from keras.src.backend.tensorflow.rnn import cudnn_ok
from keras.src.backend.tensorflow.rnn import gru
from keras.src.backend.tensorflow.rnn import lstm
Expand Down
10 changes: 10 additions & 0 deletions keras/src/backend/tensorflow/distribution_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,3 +85,13 @@ def _to_backend_layout(tensor_layout):
]
dtensor_mesh = tensor_layout.device_mesh.backend_mesh
return dtensor.Layout(sharding_specs=sharding_specs, mesh=dtensor_mesh)


def process_id():
    """Return the current process ID for the distribution setting.

    Returns:
        int: The replica index of the current replica context, or 0 when
            TensorFlow or a replica context is unavailable.
    """
    try:
        import tensorflow as tf

        # NOTE(review): `replica_id_in_sync_group` is the replica index
        # within the sync group, which matches the process index only in
        # one-replica-per-process setups — confirm for multi-replica
        # workers. It can also be a scalar Tensor, so cast explicitly so
        # callers always receive a Python int.
        return int(
            tf.distribute.get_replica_context().replica_id_in_sync_group
        )
    except (
        ImportError,
        AttributeError,
        RuntimeError,
        TypeError,
        ValueError,
    ):
        # TF missing, no replica context, or a non-castable symbolic id.
        return 0
1 change: 1 addition & 0 deletions keras/src/backend/torch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
from keras.src.backend.torch.core import stop_gradient
from keras.src.backend.torch.core import to_torch_dtype
from keras.src.backend.torch.core import vectorized_map
from keras.src.backend.torch.distribution_lib import process_id
from keras.src.backend.torch.rnn import cudnn_ok
from keras.src.backend.torch.rnn import gru
from keras.src.backend.torch.rnn import lstm
Expand Down
13 changes: 13 additions & 0 deletions keras/src/backend/torch/distribution_lib.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
"""Utilities for distribution strategy with PyTorch backend."""


def process_id():
    """Return the current process ID for the distribution setting.

    Reports the rank from `torch.distributed` when a process group is
    initialized; otherwise falls back to 0.
    """
    try:
        import torch.distributed as dist
    except ImportError:
        # torch not installed: single-process by definition.
        return 0
    try:
        if not (dist.is_available() and dist.is_initialized()):
            return 0
        return dist.get_rank()
    except AttributeError:
        return 0
1 change: 1 addition & 0 deletions keras/src/callbacks/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from keras.src.callbacks.learning_rate_scheduler import LearningRateScheduler
from keras.src.callbacks.model_checkpoint import ModelCheckpoint
from keras.src.callbacks.monitor_callback import MonitorCallback
from keras.src.callbacks.orbax_checkpoint import OrbaxCheckpoint
from keras.src.callbacks.progbar_logger import ProgbarLogger
from keras.src.callbacks.reduce_lr_on_plateau import ReduceLROnPlateau
from keras.src.callbacks.remote_monitor import RemoteMonitor
Expand Down
Loading