Skip to content

Commit eb7855d

Browse files
Migration to Orbax V1
1 parent 7722e30 commit eb7855d

File tree

20 files changed

+430
-819
lines changed

20 files changed

+430
-819
lines changed

keras/api/_tf_keras/keras/distribution/__init__.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717
from keras.src.distribution.distribution_lib import distribution as distribution
1818
from keras.src.distribution.distribution_lib import initialize as initialize
1919
from keras.src.distribution.distribution_lib import list_devices as list_devices
20-
from keras.src.distribution.distribution_lib import process_id as process_id
2120
from keras.src.distribution.distribution_lib import (
2221
set_distribution as set_distribution,
2322
)

keras/api/distribution/__init__.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717
from keras.src.distribution.distribution_lib import distribution as distribution
1818
from keras.src.distribution.distribution_lib import initialize as initialize
1919
from keras.src.distribution.distribution_lib import list_devices as list_devices
20-
from keras.src.distribution.distribution_lib import process_id as process_id
2120
from keras.src.distribution.distribution_lib import (
2221
set_distribution as set_distribution,
2322
)

keras/src/backend/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,9 +36,11 @@
3636
# Import backend functions.
3737
if backend() == "tensorflow":
3838
from keras.src.backend.tensorflow import * # noqa: F403
39+
from keras.src.backend.tensorflow import distribution_lib
3940
from keras.src.backend.tensorflow.core import Variable as BackendVariable
4041
elif backend() == "jax":
4142
from keras.src.backend.jax import * # noqa: F403
43+
from keras.src.backend.jax import distribution_lib
4244
from keras.src.backend.jax.core import Variable as BackendVariable
4345
elif backend() == "torch":
4446
from keras.src.backend.torch import * # noqa: F403

keras/src/backend/jax/__init__.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414
from keras.src.backend.jax.core import cast
1515
from keras.src.backend.jax.core import compute_output_spec
1616
from keras.src.backend.jax.core import cond
17-
from keras.src.backend.jax.core import convert_checkpoint_value
1817
from keras.src.backend.jax.core import convert_to_numpy
1918
from keras.src.backend.jax.core import convert_to_tensor
2019
from keras.src.backend.jax.core import device_scope
@@ -25,7 +24,6 @@
2524
from keras.src.backend.jax.core import shape
2625
from keras.src.backend.jax.core import stop_gradient
2726
from keras.src.backend.jax.core import vectorized_map
28-
from keras.src.backend.jax.distribution_lib import process_id
2927
from keras.src.backend.jax.rnn import cudnn_ok
3028
from keras.src.backend.jax.rnn import gru
3129
from keras.src.backend.jax.rnn import lstm

keras/src/backend/jax/core.py

Lines changed: 0 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -572,35 +572,3 @@ def device_scope(device_name):
572572
else:
573573
jax_device = device_name
574574
return jax.default_device(jax_device)
575-
576-
577-
def convert_checkpoint_value(value, dtype, shape):
578-
"""Convert a value for checkpoint restoration, preserving JAX arrays for
579-
sharding.
580-
581-
This function handles the special case of checkpoint restoration where JAX
582-
arrays should be preserved for sharding support, while other values are
583-
converted to JAX arrays with the specified dtype and shape.
584-
585-
Args:
586-
value: The value to convert (can be JAX array, numpy array, or other
587-
types)
588-
dtype: The target dtype
589-
shape: The target shape
590-
591-
Returns:
592-
A JAX array with the specified dtype and shape, or the original JAX
593-
array if it was already a JAX array.
594-
"""
595-
# For JAX backend, preserve JAX arrays for sharding support
596-
if hasattr(value, "__array_namespace__") or str(type(value)).startswith(
597-
"<class 'jax"
598-
):
599-
# value is already a JAX array, return as-is to preserve sharding
600-
return value
601-
elif isinstance(value, np.ndarray):
602-
# Convert numpy array to JAX array
603-
return jnp.array(value).astype(dtype).reshape(shape)
604-
else:
605-
# Convert other types to JAX array
606-
return jnp.array(value, dtype=dtype).reshape(shape)

keras/src/backend/jax/distribution_lib.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -193,11 +193,6 @@ def num_processes():
193193
return jax.process_count()
194194

195195

196-
def process_id():
197-
"""Return the current process ID for the distribution setting."""
198-
return jax.process_index()
199-
200-
201196
def _to_backend_device(device_name):
202197
if isinstance(device_name, jax.Device):
203198
return device_name

keras/src/backend/numpy/__init__.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,10 @@
2020
from keras.src.backend.numpy.core import random_seed_dtype
2121
from keras.src.backend.numpy.core import shape
2222
from keras.src.backend.numpy.core import vectorized_map
23-
from keras.src.backend.numpy.distribution_lib import process_id
2423
from keras.src.backend.numpy.rnn import cudnn_ok
2524
from keras.src.backend.numpy.rnn import gru
2625
from keras.src.backend.numpy.rnn import lstm
2726
from keras.src.backend.numpy.rnn import rnn
27+
28+
# Numpy backend does not support distribution
29+
distribution_lib = None

keras/src/backend/numpy/distribution_lib.py

Lines changed: 0 additions & 6 deletions
This file was deleted.

keras/src/backend/openvino/__init__.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
from keras.src.backend.common.name_scope import name_scope
22
from keras.src.backend.openvino import core
3-
from keras.src.backend.openvino import distribution_lib
43
from keras.src.backend.openvino import image
54
from keras.src.backend.openvino import linalg
65
from keras.src.backend.openvino import math
@@ -20,7 +19,6 @@
2019
from keras.src.backend.openvino.core import random_seed_dtype
2120
from keras.src.backend.openvino.core import shape
2221
from keras.src.backend.openvino.core import vectorized_map
23-
from keras.src.backend.openvino.distribution_lib import process_id
2422
from keras.src.backend.openvino.rnn import cudnn_ok
2523
from keras.src.backend.openvino.rnn import gru
2624
from keras.src.backend.openvino.rnn import lstm

keras/src/backend/openvino/distribution_lib.py

Lines changed: 0 additions & 6 deletions
This file was deleted.

0 commit comments

Comments
 (0)