Skip to content

Commit 822396f

Browse files
Improve OrbaxCheckpoint implementation
- Remove conditional export decorator to ensure OrbaxCheckpoint is always available - Remove unnecessary exception handling in state tree operations - Update process index check comment for clarity - Format code to comply with 80-character line limit - Add distribution_lib modules for backend-specific distributed training support
1 parent 7742139 commit 822396f

File tree

12 files changed

+366
-226
lines changed

12 files changed

+366
-226
lines changed

keras/src/backend/jax/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
from keras.src.backend.config import is_nnx_enabled
22
from keras.src.backend.jax import core
3-
from keras.src.backend.jax import distribution_lib
43
from keras.src.backend.jax import image
54
from keras.src.backend.jax import linalg
65
from keras.src.backend.jax import math
@@ -29,3 +28,4 @@
2928
from keras.src.backend.jax.rnn import gru
3029
from keras.src.backend.jax.rnn import lstm
3130
from keras.src.backend.jax.rnn import rnn
31+
from keras.src.backend.jax.distribution_lib import process_id

keras/src/backend/numpy/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,3 +24,4 @@
2424
from keras.src.backend.numpy.rnn import gru
2525
from keras.src.backend.numpy.rnn import lstm
2626
from keras.src.backend.numpy.rnn import rnn
27+
from keras.src.backend.numpy.distribution_lib import process_id
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
"""Utilities for distribution strategy with NumPy backend."""
2+
3+
4+
def process_id():
5+
"""Return the current process ID for the distribution setting."""
6+
return 0

keras/src/backend/openvino/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from keras.src.backend.common.name_scope import name_scope
22
from keras.src.backend.openvino import core
3+
from keras.src.backend.openvino import distribution_lib
34
from keras.src.backend.openvino import image
45
from keras.src.backend.openvino import linalg
56
from keras.src.backend.openvino import math
@@ -23,3 +24,4 @@
2324
from keras.src.backend.openvino.rnn import gru
2425
from keras.src.backend.openvino.rnn import lstm
2526
from keras.src.backend.openvino.rnn import rnn
27+
from keras.src.backend.openvino.distribution_lib import process_id
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
"""Utilities for distribution strategy with OpenVINO backend."""
2+
3+
4+
def process_id():
5+
"""Return the current process ID for the distribution setting."""
6+
return 0

keras/src/backend/tensorflow/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
from keras.src.backend.tensorflow import core
2-
from keras.src.backend.tensorflow import distribution_lib
32
from keras.src.backend.tensorflow import image
43
from keras.src.backend.tensorflow import linalg
54
from keras.src.backend.tensorflow import math
@@ -28,3 +27,4 @@
2827
from keras.src.backend.tensorflow.rnn import gru
2928
from keras.src.backend.tensorflow.rnn import lstm
3029
from keras.src.backend.tensorflow.rnn import rnn
30+
from keras.src.backend.tensorflow.distribution_lib import process_id

keras/src/backend/tensorflow/distribution_lib.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,3 +85,13 @@ def _to_backend_layout(tensor_layout):
8585
]
8686
dtensor_mesh = tensor_layout.device_mesh.backend_mesh
8787
return dtensor.Layout(sharding_specs=sharding_specs, mesh=dtensor_mesh)
88+
89+
90+
def process_id():
91+
"""Return the current process ID for the distribution setting."""
92+
try:
93+
import tensorflow as tf
94+
95+
return tf.distribute.get_replica_context().replica_id_in_sync_group
96+
except (ImportError, AttributeError, RuntimeError):
97+
return 0

keras/src/backend/torch/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,3 +43,4 @@
4343
from keras.src.backend.torch.rnn import gru
4444
from keras.src.backend.torch.rnn import lstm
4545
from keras.src.backend.torch.rnn import rnn
46+
from keras.src.backend.torch.distribution_lib import process_id
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
"""Utilities for distribution strategy with PyTorch backend."""
2+
3+
4+
def process_id():
5+
"""Return the current process ID for the distribution setting."""
6+
try:
7+
import torch.distributed as dist
8+
9+
if dist.is_available() and dist.is_initialized():
10+
return dist.get_rank()
11+
return 0
12+
except (ImportError, AttributeError):
13+
return 0

keras/src/callbacks/__init__.py

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,7 @@
88
from keras.src.callbacks.learning_rate_scheduler import LearningRateScheduler
99
from keras.src.callbacks.model_checkpoint import ModelCheckpoint
1010
from keras.src.callbacks.monitor_callback import MonitorCallback
11-
12-
try:
13-
from keras.src.callbacks.orbax_checkpoint import OrbaxCheckpoint
14-
except ImportError:
15-
OrbaxCheckpoint = None
16-
11+
from keras.src.callbacks.orbax_checkpoint import OrbaxCheckpoint
1712
from keras.src.callbacks.progbar_logger import ProgbarLogger
1813
from keras.src.callbacks.reduce_lr_on_plateau import ReduceLROnPlateau
1914
from keras.src.callbacks.remote_monitor import RemoteMonitor

0 commit comments

Comments
 (0)