Skip to content

Commit 5229939

Browse files
committed
Add tests to cover frexp and new class
1 parent 5de2f64 commit 5229939

File tree

2 files changed

+218
-62
lines changed

2 files changed

+218
-62
lines changed

dpnp/dpnp_algo/dpnp_elementwise_common.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -179,7 +179,7 @@ def __call__(
179179
)
180180
elif dtype is not None and out is not None:
181181
raise TypeError(
182-
f"Requested function={self.name_} only takes `out` or `dtype`"
182+
f"Requested function={self.name_} only takes `out` or `dtype` "
183183
"as an argument, but both were provided."
184184
)
185185

@@ -356,6 +356,12 @@ def __call__(
356356

357357
res_dt = res_dts[i]
358358
if res_dt != res.dtype:
359+
if not dpnp.can_cast(res_dt, res.dtype, casting="same_kind"):
360+
raise TypeError(
361+
f"Cannot cast ufunc '{self.name_}' output {i + 1} from "
362+
f"{res_dt} to {res.dtype} with casting rule 'same_kind'"
363+
)
364+
359365
# Allocate a temporary buffer with the required dtype
360366
out[i] = dpt.empty_like(res, dtype=res_dt)
361367
elif (
@@ -564,7 +570,7 @@ def __call__(
564570
)
565571
elif dtype is not None and out is not None:
566572
raise TypeError(
567-
f"Requested function={self.name_} only takes `out` or `dtype`"
573+
f"Requested function={self.name_} only takes `out` or `dtype` "
568574
"as an argument, but both were provided."
569575
)
570576

dpnp/tests/test_mathematical.py

Lines changed: 210 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -710,6 +710,112 @@ def test_errors(self):
710710
assert_raises(ExecutionPlacementError, dpnp.ediff1d, ia, to_end=to_end)
711711

712712

713+
class TestFrexp:
714+
ALL_DTYPES = get_all_dtypes(no_none=True)
715+
ALL_DTYPES_NO_COMPLEX = get_all_dtypes(
716+
no_none=True, no_float16=False, no_complex=True
717+
)
718+
719+
@pytest.mark.parametrize("dt", ALL_DTYPES_NO_COMPLEX)
720+
def test_basic(self, dt):
721+
a = numpy.array([-2, 5, 1, 4, 3], dtype=dt)
722+
ia = dpnp.array(a)
723+
724+
res1, res2 = dpnp.frexp(ia)
725+
exp1, exp2 = numpy.frexp(a)
726+
assert_array_equal(res1, exp1)
727+
assert_array_equal(res2, exp2)
728+
729+
def test_out(self):
730+
a = numpy.array(5.7)
731+
ia = dpnp.array(a)
732+
733+
out1 = numpy.empty(())
734+
out2 = numpy.empty((), dtype=numpy.int32)
735+
iout1, iout2 = dpnp.array(out1), dpnp.array(out2)
736+
737+
res1, res2 = dpnp.frexp(ia, iout1)
738+
exp1, exp2 = numpy.frexp(a, out1)
739+
assert_array_equal(res1, exp1)
740+
assert_array_equal(res2, exp2)
741+
assert res1 is iout1
742+
743+
res1, res2 = dpnp.frexp(ia, None, iout2)
744+
exp1, exp2 = numpy.frexp(a, None, out2)
745+
assert_array_equal(res1, exp1)
746+
assert_array_equal(res2, exp2)
747+
assert res2 is iout2
748+
749+
res1, res2 = dpnp.frexp(ia, iout1, iout2)
750+
exp1, exp2 = numpy.frexp(a, out1, out2)
751+
assert_array_equal(res1, exp1)
752+
assert_array_equal(res2, exp2)
753+
assert res1 is iout1
754+
assert res2 is iout2
755+
756+
@pytest.mark.parametrize("dt", ALL_DTYPES_NO_COMPLEX)
757+
@pytest.mark.parametrize("out1_dt", ALL_DTYPES)
758+
@pytest.mark.parametrize("out2_dt", ALL_DTYPES)
759+
def test_out_all_dtypes(self, dt, out1_dt, out2_dt):
760+
a = numpy.ones(9, dtype=dt)
761+
ia = dpnp.array(a)
762+
763+
out1 = numpy.zeros(9, dtype=out1_dt)
764+
out2 = numpy.zeros(9, dtype=out2_dt)
765+
iout1, iout2 = dpnp.array(out1), dpnp.array(out2)
766+
767+
try:
768+
res1, res2 = dpnp.frexp(ia, out=(iout1, iout2))
769+
except TypeError:
770+
# expect numpy to fail with the same reason
771+
with pytest.raises(TypeError):
772+
_ = numpy.frexp(a, out=(out1, out2))
773+
else:
774+
exp1, exp2 = numpy.frexp(a, out=(out1, out2))
775+
assert_array_equal(res1, exp1)
776+
assert_array_equal(res2, exp2)
777+
assert res1 is iout1
778+
assert res2 is iout2
779+
780+
@pytest.mark.parametrize("stride", [-4, -2, -1, 1, 2, 4])
781+
@pytest.mark.parametrize("dt", get_float_dtypes())
782+
def test_strides_out(self, stride, dt):
783+
a = numpy.array(
784+
[numpy.nan, numpy.nan, numpy.inf, -numpy.inf, 0.0, -0.0, 1.0, -1.0],
785+
dtype=dt,
786+
)
787+
ia = dpnp.array(a)
788+
789+
out_mant = numpy.ones(8, dtype=dt)
790+
out_exp = 2 * numpy.ones(8, dtype="i")
791+
iout_mant, iout_exp = dpnp.array(out_mant), dpnp.array(out_exp)
792+
793+
res1, res2 = dpnp.frexp(
794+
ia[::stride], out=(iout_mant[::stride], iout_exp[::stride])
795+
)
796+
exp1, exp2 = numpy.frexp(
797+
a[::stride], out=(out_mant[::stride], out_exp[::stride])
798+
)
799+
assert_array_equal(res1, exp1)
800+
assert_array_equal(res2, exp2)
801+
802+
assert_array_equal(iout_mant, out_mant)
803+
assert_array_equal(iout_exp, out_exp)
804+
805+
@pytest.mark.parametrize("xp", [numpy, dpnp])
806+
def test_out_wrong_type(self, xp):
807+
a = xp.array(0.5)
808+
with pytest.raises(TypeError, match="'out' must be a tuple of arrays"):
809+
_ = xp.frexp(a, out=xp.empty(()))
810+
811+
@pytest.mark.parametrize("xp", [numpy, dpnp])
812+
@pytest.mark.parametrize("dt", get_complex_dtypes())
813+
def test_complex_dtype(self, xp, dt):
814+
a = xp.array([-2, 5, 1, 4, 3], dtype=dt)
815+
with pytest.raises((TypeError, ValueError)):
816+
_ = xp.frexp(a)
817+
818+
713819
class TestGradient:
714820
@pytest.mark.parametrize("dt", get_all_dtypes(no_none=True, no_bool=True))
715821
def test_basic(self, dt):
@@ -1925,6 +2031,110 @@ def test_ndim(self):
19252031
assert_dtype_allclose(result, expected)
19262032

19272033

2034+
class TestUfunc:
2035+
@pytest.mark.parametrize(
2036+
"func, nin, nout",
2037+
[
2038+
pytest.param("abs", 1, 1, id="DPNPUnaryFunc"),
2039+
pytest.param("frexp", 1, 2, id="DPNPUnaryTwoOutputsFunc"),
2040+
pytest.param("add", 2, 1, id="DPNPBinaryFunc"),
2041+
],
2042+
)
2043+
def test_nin_nout(self, func, nin, nout):
2044+
assert getattr(dpnp, func).nin == nin
2045+
assert getattr(dpnp, func).nout == nout
2046+
2047+
@pytest.mark.parametrize(
2048+
"func, kwargs",
2049+
[
2050+
pytest.param(
2051+
"abs",
2052+
{"unknown_kwarg": 1, "where": False, "subok": False},
2053+
id="DPNPUnaryFunc",
2054+
),
2055+
pytest.param(
2056+
"frexp",
2057+
{
2058+
"unknown_kwarg": 1,
2059+
"where": False,
2060+
"dtype": "?",
2061+
"subok": False,
2062+
},
2063+
id="DPNPUnaryTwoOutputsFunc",
2064+
),
2065+
pytest.param(
2066+
"add",
2067+
{"unknown_kwarg": 1, "where": False, "subok": False},
2068+
id="DPNPBinaryFunc",
2069+
),
2070+
],
2071+
)
2072+
def test_not_supported_kwargs(self, func, kwargs):
2073+
x = dpnp.array([1, 2, 3])
2074+
2075+
fn = getattr(dpnp, func)
2076+
args = [x] * fn.nin
2077+
for key, val in kwargs.items():
2078+
with pytest.raises(NotImplementedError):
2079+
fn(*args, **{key: val})
2080+
2081+
@pytest.mark.parametrize("func", ["abs", "frexp", "add"])
2082+
@pytest.mark.parametrize("x", [1, [1, 2], numpy.ones(5)])
2083+
def test_wrong_input(self, func, x):
2084+
fn = getattr(dpnp, func)
2085+
args = [x] * fn.nin
2086+
with pytest.raises(TypeError):
2087+
fn(*args)
2088+
2089+
@pytest.mark.parametrize("func", ["add"])
2090+
def test_binary_wrong_input(self, func):
2091+
x = dpnp.array([1, 2, 3])
2092+
with pytest.raises(TypeError):
2093+
getattr(dpnp, func)(x, [1, 2])
2094+
with pytest.raises(TypeError):
2095+
getattr(dpnp, func)([1, 2], x)
2096+
2097+
@pytest.mark.parametrize("func", ["abs", "frexp", "add"])
2098+
def test_wrong_order(self, func):
2099+
x = dpnp.array([1, 2, 3])
2100+
2101+
fn = getattr(dpnp, func)
2102+
args = [x] * fn.nin
2103+
with pytest.raises(ValueError, match="order must be one of"):
2104+
fn(*args, order="H")
2105+
2106+
@pytest.mark.parametrize("func", ["abs", "add"])
2107+
def test_out_dtype(self, func):
2108+
x = dpnp.array([1, 2, 3])
2109+
out = dpnp.array([1, 2, 3])
2110+
2111+
fn = getattr(dpnp, func)
2112+
args = [x] * fn.nin
2113+
with pytest.raises(
2114+
TypeError, match="only takes `out` or `dtype` as an argument"
2115+
):
2116+
fn(*args, out=out, dtype="f4")
2117+
2118+
@pytest.mark.parametrize("func", ["abs", "frexp", "add"])
2119+
def test_order_none(self, func):
2120+
a = numpy.array([1, 2, 3])
2121+
ia = dpnp.array(a)
2122+
2123+
fn = getattr(numpy, func)
2124+
ifn = getattr(dpnp, func)
2125+
2126+
args = [a] * fn.nin
2127+
iargs = [ia] * ifn.nin
2128+
2129+
result = ifn(*iargs, order=None)
2130+
expected = fn(*args, order=None)
2131+
if fn.nout == 1:
2132+
assert_dtype_allclose(result, expected)
2133+
else:
2134+
for i in range(fn.nout):
2135+
assert_dtype_allclose(result[i], expected[i])
2136+
2137+
19282138
class TestUnwrap:
19292139
@pytest.mark.parametrize("dt", get_float_dtypes())
19302140
def test_basic(self, dt):
@@ -2568,66 +2778,6 @@ def test_inplace_floor_divide(dtype):
25682778
assert_allclose(ia, a)
25692779

25702780

2571-
def test_elemenwise_nin_nout():
2572-
assert dpnp.abs.nin == 1
2573-
assert dpnp.add.nin == 2
2574-
2575-
assert dpnp.abs.nout == 1
2576-
assert dpnp.add.nout == 1
2577-
2578-
2579-
def test_elemenwise_error():
2580-
x = dpnp.array([1, 2, 3])
2581-
out = dpnp.array([1, 2, 3])
2582-
2583-
with pytest.raises(NotImplementedError):
2584-
dpnp.abs(x, unknown_kwarg=1)
2585-
with pytest.raises(NotImplementedError):
2586-
dpnp.abs(x, where=False)
2587-
with pytest.raises(NotImplementedError):
2588-
dpnp.abs(x, subok=False)
2589-
with pytest.raises(TypeError):
2590-
dpnp.abs(1)
2591-
with pytest.raises(TypeError):
2592-
dpnp.abs([1, 2])
2593-
with pytest.raises(TypeError):
2594-
dpnp.abs(x, out=out, dtype="f4")
2595-
with pytest.raises(ValueError):
2596-
dpnp.abs(x, order="H")
2597-
2598-
with pytest.raises(NotImplementedError):
2599-
dpnp.add(x, x, unknown_kwarg=1)
2600-
with pytest.raises(NotImplementedError):
2601-
dpnp.add(x, x, where=False)
2602-
with pytest.raises(NotImplementedError):
2603-
dpnp.add(x, x, subok=False)
2604-
with pytest.raises(TypeError):
2605-
dpnp.add(1, 2)
2606-
with pytest.raises(TypeError):
2607-
dpnp.add([1, 2], [1, 2])
2608-
with pytest.raises(TypeError):
2609-
dpnp.add(x, [1, 2])
2610-
with pytest.raises(TypeError):
2611-
dpnp.add([1, 2], x)
2612-
with pytest.raises(TypeError):
2613-
dpnp.add(x, x, out=out, dtype="f4")
2614-
with pytest.raises(ValueError):
2615-
dpnp.add(x, x, order="H")
2616-
2617-
2618-
def test_elemenwise_order_none():
2619-
x_np = numpy.array([1, 2, 3])
2620-
x = dpnp.array([1, 2, 3])
2621-
2622-
result = dpnp.abs(x, order=None)
2623-
expected = numpy.abs(x_np, order=None)
2624-
assert_dtype_allclose(result, expected)
2625-
2626-
result = dpnp.add(x, x, order=None)
2627-
expected = numpy.add(x_np, x_np, order=None)
2628-
assert_dtype_allclose(result, expected)
2629-
2630-
26312781
def test_bitwise_1array_input():
26322782
x = dpnp.array([1, 2, 3])
26332783
x_np = numpy.array([1, 2, 3])

0 commit comments

Comments
 (0)