Skip to content

Commit 621f566

Browse files
Migration to Orbax V1
1 parent 7722e30 commit 621f566

File tree

10 files changed

+411
-664
lines changed

10 files changed

+411
-664
lines changed

keras/src/backend/jax/__init__.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414
from keras.src.backend.jax.core import cast
1515
from keras.src.backend.jax.core import compute_output_spec
1616
from keras.src.backend.jax.core import cond
17-
from keras.src.backend.jax.core import convert_checkpoint_value
1817
from keras.src.backend.jax.core import convert_to_numpy
1918
from keras.src.backend.jax.core import convert_to_tensor
2019
from keras.src.backend.jax.core import device_scope

keras/src/backend/jax/core.py

Lines changed: 0 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -572,35 +572,3 @@ def device_scope(device_name):
572572
else:
573573
jax_device = device_name
574574
return jax.default_device(jax_device)
575-
576-
577-
def convert_checkpoint_value(value, dtype, shape):
578-
"""Convert a value for checkpoint restoration, preserving JAX arrays for
579-
sharding.
580-
581-
This function handles the special case of checkpoint restoration where JAX
582-
arrays should be preserved for sharding support, while other values are
583-
converted to JAX arrays with the specified dtype and shape.
584-
585-
Args:
586-
value: The value to convert (can be JAX array, numpy array, or other
587-
types)
588-
dtype: The target dtype
589-
shape: The target shape
590-
591-
Returns:
592-
A JAX array with the specified dtype and shape, or the original JAX
593-
array if it was already a JAX array.
594-
"""
595-
# For JAX backend, preserve JAX arrays for sharding support
596-
if hasattr(value, "__array_namespace__") or str(type(value)).startswith(
597-
"<class 'jax"
598-
):
599-
# value is already a JAX array, return as-is to preserve sharding
600-
return value
601-
elif isinstance(value, np.ndarray):
602-
# Convert numpy array to JAX array
603-
return jnp.array(value).astype(dtype).reshape(shape)
604-
else:
605-
# Convert other types to JAX array
606-
return jnp.array(value, dtype=dtype).reshape(shape)

keras/src/backend/openvino/__init__.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
from keras.src.backend.common.name_scope import name_scope
22
from keras.src.backend.openvino import core
3-
from keras.src.backend.openvino import distribution_lib
43
from keras.src.backend.openvino import image
54
from keras.src.backend.openvino import linalg
65
from keras.src.backend.openvino import math

keras/src/backend/tensorflow/__init__.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313
from keras.src.backend.tensorflow.core import cast
1414
from keras.src.backend.tensorflow.core import compute_output_spec
1515
from keras.src.backend.tensorflow.core import cond
16-
from keras.src.backend.tensorflow.core import convert_checkpoint_value
1716
from keras.src.backend.tensorflow.core import convert_to_numpy
1817
from keras.src.backend.tensorflow.core import convert_to_tensor
1918
from keras.src.backend.tensorflow.core import device_scope

keras/src/backend/tensorflow/core.py

Lines changed: 0 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -696,23 +696,3 @@ def __exit__(self, *args, **kwargs):
696696

697697
def device_scope(device_name):
698698
return tf.device(device_name)
699-
700-
701-
def convert_checkpoint_value(value, dtype, shape):
702-
"""Convert a value for checkpoint restoration.
703-
704-
For TensorFlow backend, convert to numpy arrays with specified dtype and
705-
shape.
706-
707-
Args:
708-
value: The value to convert
709-
dtype: The target dtype
710-
shape: The target shape
711-
712-
Returns:
713-
A numpy array with the specified dtype and shape.
714-
"""
715-
if isinstance(value, np.ndarray):
716-
return value.astype(dtype).reshape(shape)
717-
else:
718-
return np.array(value, dtype=dtype).reshape(shape)

keras/src/backend/torch/__init__.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,6 @@
2929
from keras.src.backend.torch.core import cast
3030
from keras.src.backend.torch.core import compute_output_spec
3131
from keras.src.backend.torch.core import cond
32-
from keras.src.backend.torch.core import convert_checkpoint_value
3332
from keras.src.backend.torch.core import convert_to_numpy
3433
from keras.src.backend.torch.core import convert_to_tensor
3534
from keras.src.backend.torch.core import device_scope

keras/src/backend/torch/core.py

Lines changed: 0 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -730,22 +730,3 @@ def backward(ctx, grad_output):
730730
if not isinstance(grads, tuple):
731731
grads = (grads,)
732732
return (None,) + grads
733-
734-
735-
def convert_checkpoint_value(value, dtype, shape):
736-
"""Convert a value for checkpoint restoration.
737-
738-
For PyTorch backend, convert to numpy arrays with specified dtype and shape.
739-
740-
Args:
741-
value: The value to convert
742-
dtype: The target dtype
743-
shape: The target shape
744-
745-
Returns:
746-
A numpy array with the specified dtype and shape.
747-
"""
748-
if isinstance(value, np.ndarray):
749-
return value.astype(dtype).reshape(shape)
750-
else:
751-
return np.array(value, dtype=dtype).reshape(shape)

0 commit comments

Comments
 (0)