Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update dpnp.linalg.solve() to align NumPy 2.0 #2198

Merged
merged 14 commits into from
Nov 28, 2024
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 32 additions & 8 deletions dpnp/linalg/dpnp_iface_linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -1612,12 +1612,12 @@ def solve(a, b):
----------
a : (..., M, M) {dpnp.ndarray, usm_ndarray}
Coefficient matrix.
b : {(…, M,), (, M, K)} {dpnp.ndarray, usm_ndarray}
b : {(M,), (..., M, K)} {dpnp.ndarray, usm_ndarray}
Ordinate or "dependent variable" values.

Returns
-------
out : {(, M,), (, M, K)} dpnp.ndarray
out : {(..., M,), (..., M, K)} dpnp.ndarray
Solution to the system `ax = b`. Returned shape is identical to `b`.

See Also
Expand All @@ -1644,14 +1644,38 @@ def solve(a, b):
assert_stacked_2d(a)
assert_stacked_square(a)

if not (
a.ndim in [b.ndim, b.ndim + 1] and a.shape[:-1] == b.shape[: a.ndim - 1]
):
raise dpnp.linalg.LinAlgError(
"a must have (..., M, M) shape and b must have (..., M) "
"or (..., M, K)"
a_shape = a.shape
b_shape = b.shape
b_ndim = b.ndim

# compatible with numpy>=2.0
if b_ndim == 0:
raise ValueError("b must have at least one dimension")
if b_ndim == 1:
if a_shape[-1] != b.size:
raise ValueError(
"a must have (..., M, M) shape and b must have (M,) "
"for one-dimensional b"
)
b = dpnp.broadcast_to(b, a_shape[:-1])
return dpnp_solve(a, b)

if a_shape[-1] != b_shape[-2]:
raise ValueError(
"a must have (..., M, M) shape and b must have (..., M, K) shape"
)

# Use dpnp.broadcast_shapes() to align the resulting batch shapes
broadcasted_batch_shape = dpnp.broadcast_shapes(a_shape[:-2], b_shape[:-2])

a_broadcasted_shape = broadcasted_batch_shape + a_shape[-2:]
b_broadcasted_shape = broadcasted_batch_shape + b_shape[-2:]

if a_shape != a_broadcasted_shape:
a = dpnp.broadcast_to(a, a_broadcasted_shape)
if b_shape != b_broadcasted_shape:
b = dpnp.broadcast_to(b, b_broadcasted_shape)

return dpnp_solve(a, b)


Expand Down
30 changes: 30 additions & 0 deletions dpnp/tests/test_linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -2694,6 +2694,36 @@ def test_solve(self, dtype):

assert_allclose(expected, result, rtol=1e-06)

@testing.with_requires("numpy>=2.0")
@pytest.mark.parametrize("dtype", get_float_complex_dtypes())
@pytest.mark.parametrize(
"a_shape, b_shape",
[
((4, 4), (2, 2, 4, 3)),
((2, 5, 5), (1, 5, 3)),
((2, 4, 4), (2, 2, 4, 2)),
((3, 2, 2), (3, 1, 2, 1)),
((2, 2, 2, 2, 2), (2,)),
((2, 2, 2, 2, 2), (2, 3)),
],
)
def test_solve_broadcast(self, a_shape, b_shape, dtype):
# Set seed_value=81 to prevent
# random generation of the input singular matrix
a_np = generate_random_numpy_array(a_shape, dtype, seed_value=81)

# Set seed_value=76 to prevent
# random generation of the input singular matrix
b_np = generate_random_numpy_array(b_shape, dtype, seed_value=76)

a_dp = inp.array(a_np)
b_dp = inp.array(b_np)

expected = numpy.linalg.solve(a_np, b_np)
result = inp.linalg.solve(a_dp, b_dp)

assert_dtype_allclose(result, expected)

@pytest.mark.parametrize("dtype", get_all_dtypes(no_bool=True))
def test_solve_nrhs_greater_n(self, dtype):
# Test checking the case when nrhs > n for
Expand Down
15 changes: 8 additions & 7 deletions dpnp/tests/third_party/cupy/linalg_tests/test_solve.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,10 +60,9 @@ def test_solve(self):
# for other cases this signature must be followed
# (..., m, m), (..., m, n) -> (..., m, n)
# https://github.com/numpy/numpy/pull/25914
if numpy.lib.NumpyVersion(numpy.__version__) < "2.0.0":
self.check_x((2, 4, 4), (2, 4))
self.check_x((2, 3, 2, 2), (2, 3, 2))
self.check_x((0, 2, 2), (0, 2))
if numpy.lib.NumpyVersion(numpy.__version__) >= "2.0.0":
self.check_x((2, 3, 3), (3,))
self.check_x((2, 5, 3, 3), (3,))

def check_shape(self, a_shape, b_shape, error_types):
for xp, error_type in error_types.items():
Expand Down Expand Up @@ -96,11 +95,13 @@ def test_invalid_shape(self):
self.check_shape((3, 3), (2,), value_errors)
self.check_shape((3, 3), (2, 2), value_errors)
self.check_shape((3, 3, 4), (3,), linalg_errors)
# Since numpy >= 2.0, this case does not raise an error
if numpy.lib.NumpyVersion(numpy.__version__) < "2.0.0":
self.check_shape((2, 3, 3), (3,), value_errors)
self.check_shape((3, 3), (0,), value_errors)
self.check_shape((0, 3, 4), (3,), linalg_errors)
# Not allowed since numpy 2.0
if numpy.lib.NumpyVersion(numpy.__version__) >= "2.0.0":
self.check_shape((0, 2, 2), (0, 2), value_errors)
self.check_shape((2, 4, 4), (2, 4), value_errors)
self.check_shape((2, 3, 2, 2), (2, 3, 2), value_errors)


@testing.parameterize(
Expand Down
Loading