Skip to content

Commit ff60e34

Browse files
authored
Some dtype fixes (#935)
* Some dtype fixes * Nits
1 parent 180e134 commit ff60e34

File tree

5 files changed

+77
-5
lines changed

5 files changed

+77
-5
lines changed

keras_core/backend/jax/numpy.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,13 @@ def append(
113113

114114

115115
def arange(start, stop=None, step=1, dtype=None):
116+
if dtype is None:
117+
if hasattr(start, "dtype"):
118+
dtype = start.dtype
119+
elif isinstance(start, int):
120+
dtype = "int32"
121+
else:
122+
dtype = config.floatx()
116123
return jnp.arange(start, stop, step=step, dtype=dtype)
117124

118125

keras_core/backend/numpy/numpy.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
11
import numpy as np
22

3+
from keras_core.backend import config
4+
from keras_core.backend import standardize_dtype
5+
36

47
def add(x1, x2):
58
return np.add(x1, x2)
@@ -77,6 +80,13 @@ def append(
7780

7881

7982
def arange(start, stop=None, step=None, dtype=None):
83+
if dtype is None:
84+
if hasattr(start, "dtype"):
85+
dtype = start.dtype
86+
elif isinstance(start, int):
87+
dtype = "int32"
88+
else:
89+
dtype = config.floatx()
8090
return np.arange(start, stop, step=step, dtype=dtype)
8191

8292

@@ -124,6 +134,7 @@ def argsort(x, axis=-1):
124134

125135

126136
def array(x, dtype=None):
137+
dtype = dtype or config.floatx()
127138
return np.array(x, dtype=dtype)
128139

129140

@@ -271,6 +282,7 @@ def floor(x):
271282

272283

273284
def full(shape, fill_value, dtype=None):
285+
dtype = dtype or config.floatx()
274286
return np.full(shape, fill_value, dtype=dtype)
275287

276288

@@ -592,7 +604,11 @@ def square(x):
592604

593605

594606
def sqrt(x):
595-
return np.sqrt(x)
607+
dtype = None
608+
if hasattr(x, "dtype"):
609+
if standardize_dtype(x.dtype).startswith("int"):
610+
dtype = config.floatx()
611+
return np.sqrt(x, dtype=dtype)
596612

597613

598614
def squeeze(x, axis=None):

keras_core/backend/tensorflow/numpy.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import tensorflow as tf
66
from tensorflow.experimental import numpy as tfnp
77

8+
from keras_core.backend import config
89
from keras_core.backend.tensorflow.core import convert_to_tensor
910

1011

@@ -176,6 +177,13 @@ def append(
176177
def arange(start, stop=None, step=1, dtype=None):
177178
# tfnp.arange has trouble with dynamic Tensors in compiled function.
178179
# tf.range does not.
180+
if dtype is None:
181+
if hasattr(start, "dtype"):
182+
dtype = start.dtype
183+
elif isinstance(start, int):
184+
dtype = "int32"
185+
else:
186+
dtype = config.floatx()
179187
return tf.range(start, stop, delta=step, dtype=dtype)
180188

181189

@@ -749,6 +757,9 @@ def square(x):
749757

750758

751759
def sqrt(x):
760+
x = convert_to_tensor(x)
761+
if tf.as_dtype(x.dtype).is_integer:
762+
x = tf.cast(x, dtype=config.floatx())
752763
return tfnp.sqrt(x)
753764

754765

keras_core/backend/torch/numpy.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import numpy as np
22
import torch
33

4+
from keras_core.backend import config
45
from keras_core.backend.torch.core import cast
56
from keras_core.backend.torch.core import convert_to_tensor
67
from keras_core.backend.torch.core import get_device
@@ -91,7 +92,7 @@ def zeros(shape, dtype="float32"):
9192

9293
def zeros_like(x, dtype=None):
9394
x = convert_to_tensor(x)
94-
dtype = to_torch_dtype(dtype)
95+
dtype = to_torch_dtype(dtype or x.dtype)
9596
return torch.zeros_like(x, dtype=dtype)
9697

9798

@@ -160,6 +161,13 @@ def append(
160161

161162

162163
def arange(start, stop=None, step=1, dtype=None):
164+
if dtype is None:
165+
if hasattr(start, "dtype"):
166+
dtype = start.dtype
167+
elif isinstance(start, int):
168+
dtype = "int32"
169+
else:
170+
dtype = config.floatx()
163171
dtype = to_torch_dtype(dtype)
164172
if stop is None:
165173
return torch.arange(end=start, dtype=dtype, device=get_device())

keras_core/ops/numpy_test.py

Lines changed: 33 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3571,9 +3571,37 @@ def test_split(self):
35713571
self.assertEqual(len(knp.Split(2)(x)), 2)
35723572

35733573
def test_sqrt(self):
3574-
x = np.array([[1, 4, 9], [16, 25, 36]])
3575-
self.assertAllClose(knp.sqrt(x), np.sqrt(x))
3576-
self.assertAllClose(knp.Sqrt()(x), np.sqrt(x))
3574+
x = np.array([[1, 4, 9], [16, 25, 36]], dtype="float32")
3575+
ref_y = np.sqrt(x)
3576+
y = knp.sqrt(x)
3577+
self.assertEqual(standardize_dtype(y.dtype), "float32")
3578+
self.assertAllClose(y, ref_y)
3579+
y = knp.Sqrt()(x)
3580+
self.assertEqual(standardize_dtype(y.dtype), "float32")
3581+
self.assertAllClose(y, ref_y)
3582+
3583+
@pytest.mark.skipif(
3584+
backend.backend() == "jax", reason="JAX does not support float64."
3585+
)
3586+
def test_sqrt_float64(self):
3587+
x = np.array([[1, 4, 9], [16, 25, 36]], dtype="float64")
3588+
ref_y = np.sqrt(x)
3589+
y = knp.sqrt(x)
3590+
self.assertEqual(standardize_dtype(y.dtype), "float64")
3591+
self.assertAllClose(y, ref_y)
3592+
y = knp.Sqrt()(x)
3593+
self.assertEqual(standardize_dtype(y.dtype), "float64")
3594+
self.assertAllClose(y, ref_y)
3595+
3596+
def test_sqrt_int32(self):
3597+
x = np.array([[1, 4, 9], [16, 25, 36]], dtype="int32")
3598+
ref_y = np.sqrt(x)
3599+
y = knp.sqrt(x)
3600+
self.assertEqual(standardize_dtype(y.dtype), "float32")
3601+
self.assertAllClose(y, ref_y)
3602+
y = knp.Sqrt()(x)
3603+
self.assertEqual(standardize_dtype(y.dtype), "float32")
3604+
self.assertAllClose(y, ref_y)
35773605

35783606
def test_stack(self):
35793607
x = np.array([[1, 2, 3], [3, 2, 1]])
@@ -3704,6 +3732,8 @@ def test_arange(self):
37043732
self.assertAllClose(knp.Arange()(3, 7), np.arange(3, 7))
37053733
self.assertAllClose(knp.Arange()(3, 7, 2), np.arange(3, 7, 2))
37063734

3735+
self.assertEqual(standardize_dtype(knp.arange(3).dtype), "int32")
3736+
37073737
def test_full(self):
37083738
self.assertAllClose(knp.full([2, 3], 0), np.full([2, 3], 0))
37093739
self.assertAllClose(knp.full([2, 3], 0.1), np.full([2, 3], 0.1))

0 commit comments

Comments
 (0)