Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
92 changes: 90 additions & 2 deletions pytensor/tensor/rewriting/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
)
from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.elemwise import DimShuffle, Elemwise
from pytensor.tensor.extra_ops import broadcast_shape
from pytensor.tensor.math import Dot, Prod, _matmul, log, outer, prod
from pytensor.tensor.nlinalg import (
SVD,
Expand Down Expand Up @@ -588,9 +589,11 @@ def svd_uv_merge(fgraph, node):
@node_rewriter([Blockwise])
def rewrite_inv_inv(fgraph, node):
"""
This rewrite takes advantage of the fact that if there are two consecutive inverse operations (inv(inv(input))), we get back our original input without having to compute inverse once.
This rewrite takes advantage of the fact that if there are two consecutive inverse operations (inv(inv(input))),
we get back our original input without having to compute inverse once.

Here, we check for direct inverse operations (inv/pinv) and allows for any combination of these "inverse" nodes to be simply rewritten.
Here, we check for direct inverse operations (inv/pinv) and allows for any combination of these "inverse" nodes to
be simply rewritten.

Parameters
----------
Expand Down Expand Up @@ -855,6 +858,91 @@ def rewrite_det_kronecker(fgraph, node):
return [det_final]


@register_canonicalize("shape_unsafe")
@register_stabilize("shape_unsafe")
@node_rewriter([Blockwise])
def rewrite_solve_kron_to_solve(fgraph, node):
"""
Given a linear system of the form:

.. math::

(A \\otimes B) x = y

Define :math:`\text{vec}(x)` as a column-wise raveling operation (``x.reshape(-1, order='F')`` in code). Further,
define :math:`y = \text{vec}(Y)`. Then the above expression can be rewritten as:

.. math::

x = \text{vec}(B^{-1} Y A^{-T})

Eliminating the kronecker product from the expression.
"""

if not isinstance(node.op.core_op, SolveBase):
return

solve_op = node.op
props_dict = solve_op.core_op._props_dict()
b_ndim = props_dict["b_ndim"]

A, b = node.inputs
[old_res] = node.outputs

if not (
A.owner
and (
isinstance(A.owner.op, KroneckerProduct)
or (
isinstance(A.owner.op, Blockwise)
and isinstance(A.owner.op.core_op, KroneckerProduct)
)
)
):
return

x1, x2 = A.owner.inputs

# If x1 and x2 have statically known core shapes, check that they are square. If not, the rewrite will be invalid.
# We will proceed if they are unknown, but this makes the rewrite shape unsafe.
Comment on lines +906 to +907
Copy link
Member

@ricardoV94 ricardoV94 Aug 10, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

shape_unsafe is when a rewrite can mask an originally invalid graph, but it / we aren't allowed to turn a previously valid graph into an invalid one. Is that what's happening here?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it's more backwards. The original graph can be valid (at least the computation will be performed -- no idea if the results make sense) but the rewrite might trigger a shape error at runtime that the user wouldn't expect or know how to debug.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

From the wiki I thought that it wouldn't be invertible if each kron is not invertible, which can only be the case for square stuff? Or numerical precision may mask some non-invertibility?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes that's true. But we end up able to get a numerical solution anyway.

Anyway I want us to stop raising errors on linalg failure and return NaN instead. If we did this we would hit this "bug" here as well, where an ill-defined system (which would have returned NaN) instead shape errors

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In this specific case, I feel it's fine to let it fail with a shape error in solve. It's also not something that's going to work some times and fail others (say like solve during early mcmc tuning), it will always fail as it is a structural thing?

It's also the simplest approach so we can complexify later if there's a need for it?

x1_core_shapes = x1.type.shape[-2:]
x2_core_shapes = x2.type.shape[-2:]

if (
all(shape is not None for shape in x1_core_shapes)
and x1_core_shapes[-1] != x1_core_shapes[-2]
) or (
all(shape is not None for shape in x2_core_shapes)
and x2_core_shapes[-1] != x2_core_shapes[-2]
):
return None

m, n = x1.shape[-2], x2.shape[-2]
batch_shapes = broadcast_shape(x1, x2)[:-2]

if b_ndim == 1:
# The rewritten expression will reshape B to be 2d. The easiest way to handle this is to just make a new
# solve node with n_ndim = 2
props_dict["b_ndim"] = 2
new_solve_op = Blockwise(type(solve_op.core_op)(**props_dict))
B = b.reshape((*batch_shapes, m, n))
res = new_solve_op(x1, new_solve_op(x2, B.mT).mT).reshape((*batch_shapes, -1))

else:
# If b_ndim is 2, we need to keep track of the original right-most dimension of b as an additional
# batch dimension
b_batch = b.shape[-1]
B = pt.moveaxis(b, -1, 0).reshape((b_batch, *batch_shapes, m, n))

res = pt.moveaxis(solve_op(x1, solve_op(x2, B.mT).mT), 0, -1).reshape(
(*batch_shapes, -1, b_batch)
)

copy_stack_trace(old_res, res)

return [res]


@register_canonicalize
@register_stabilize
@node_rewriter([Blockwise])
Expand Down
150 changes: 150 additions & 0 deletions tests/tensor/rewriting/test_linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -828,6 +828,156 @@ def test_slogdet_kronecker_rewrite():
)


def count_kron_ops(fgraph):
return sum(
[
isinstance(node.op, KroneckerProduct)
or (
isinstance(node.op, Blockwise)
and isinstance(node.op.core_op, KroneckerProduct)
)
for node in fgraph.apply_nodes
]
)


@pytest.mark.parametrize("add_batch", [True, False], ids=["batched", "not_batched"])
@pytest.mark.parametrize("b_ndim", [1, 2], ids=["b_ndim_1", "b_ndim_2"])
@pytest.mark.parametrize(
"solve_op, solve_kwargs",
[
(pt.linalg.solve, {"assume_a": "gen"}),
(pt.linalg.solve, {"assume_a": "pos"}),
(pt.linalg.solve, {"assume_a": "upper triangular"}),
],
ids=["general", "positive definite", "triangular"],
)
def test_rewrite_solve_kron_to_solve(add_batch, b_ndim, solve_op, solve_kwargs):
# A and B have different shapes to make the test more interesting, but both need to be square matrices, otherwise
# the rewrite is invalid.
a_shape = (3, 3) if not add_batch else (2, 3, 3)
b_shape = (2, 2) if not add_batch else (2, 2, 2)
A, B = pt.tensor("A", shape=a_shape), pt.tensor("B", shape=b_shape)

m, n = a_shape[-2], b_shape[-2]
y_shape = (m * n,)
if b_ndim == 2:
y_shape = (m * n, 3)
if add_batch:
y_shape = (2, *y_shape)

y = pt.tensor("y", shape=y_shape)
C = pt.vectorize(pt.linalg.kron, "(i,j),(k,l)->(m,n)")(A, B)

x = solve_op(C, y, **solve_kwargs, b_ndim=b_ndim)

fn_expected = pytensor.function(
[A, B, y], x, mode=get_default_mode().excluding("rewrite_solve_kron_to_solve")
)
assert count_kron_ops(fn_expected.maker.fgraph) == 1

fn = pytensor.function([A, B, y], x)
assert count_kron_ops(fn.maker.fgraph) == 0

rng = np.random.default_rng(sum(map(ord, "Go away Kron!")))
a_val = rng.normal(size=a_shape)
b_val = rng.normal(size=b_shape)
y_val = rng.normal(size=y_shape)

if solve_kwargs["assume_a"] == "pos":
a_val = a_val @ np.moveaxis(a_val, -2, -1)
b_val = b_val @ np.moveaxis(b_val, -2, -1)
elif solve_kwargs["assume_a"] == "upper triangular":
a_idx = np.tril_indices(n=a_shape[-2], m=a_shape[-1], k=-1)
b_idx = np.tril_indices(n=b_shape[-2], m=b_shape[-1], k=-1)

if len(a_shape) > 2:
a_idx = (slice(None, None), *a_idx)
if len(b_shape) > 2:
b_idx = (slice(None, None), *b_idx)

a_val[a_idx] = 0
b_val[b_idx] = 0

a_val = a_val.astype(config.floatX)
b_val = b_val.astype(config.floatX)
y_val = y_val.astype(config.floatX)

expected = fn_expected(a_val, b_val, y_val)
result = fn(a_val, b_val, y_val)

if config.floatX == "float64":
tol = 1e-8
elif config.floatX == "float32" and not solve_kwargs["assume_a"] == "pos":
tol = 1e-4
else:
# Precision needs to be extremely low for the assume_a = pos test to pass in float32 mode. I don't have a
# good theory of why. Skipping this case would also be an option.
tol = 1e-2

np.testing.assert_allclose(
expected,
result,
atol=tol,
rtol=tol,
)


def test_rewrite_solve_kron_to_solve_not_applied():
# Check that the rewrite is not applied when the component matrices to the kron are static and not square
A = pt.tensor("A", shape=(3, 2))
B = pt.tensor("B", shape=(2, 3))
C = pt.linalg.kron(A, B)

y = pt.vector("y", shape=(6,))
x = pt.linalg.solve(C, y)

fn = pytensor.function([A, B, y], x)

assert count_kron_ops(fn.maker.fgraph) == 1

# If shapes are static, it should always be applied
A = pt.tensor("A", shape=(3, None, None))
B = pt.tensor("B", shape=(3, None, None))
Comment on lines +939 to +941
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Back to the previous comment, is the previous C a valid graph? If so, we can't rewrite and break the graph if we don't know the core shapes.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

C is valid, because C is square. The "problem" is that we can kron together two non-square matrices and end up with a square one (e.g. kron((4,3), (3,4)) -> (7, 7)). So the rewrite is invalid in this case.

This is another case where we really really wish we had a tag for "square matrix", without having to commit to shapes.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Wiki seems to suggest Kron(A, B) is only invertible if both A and B are invertible, so you couldn't solve C in the first place if this wasn't the case?

Is that correct? In that case it's fine to have the rewrite when the shapes are unknown (perhaps add a comment?). Otherwise it's not.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The theory looks right.

The only issue I guess is that currently, you won't get an error if you have an "invalid" graph like:

A = rng.normal(size=(4, 3))
B = rng.normal(size=(3, 4))

A_pt, B_pt = pt.dmatrices('A', 'B')
y_pt = pt.dvector('y')
C = pt.linalg.kron(A_pt, B_pt)
x = pt.linalg.solve(C, y_pt)

fn = pytensor.function([A_pt, B_pt, y_pt], x)

You get a warning about numerical instability, but it gives you some numbers. Obviously these numbers are just nonsense, but it doesn't error. After the rewrite, you will get a shape error, which might be very surprising for someone who isn't providing a valid graph in the first place?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Solve of C doesn't raise for "singular matrix"?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, but the condition number is super gnarly

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it's fine to just tag with shape_unsafe, so users can disable to debug.

C = pt.linalg.kron(A, B)
y = pt.tensor("y", shape=(None,))
x = pt.linalg.solve(C, y)
fn = pytensor.function([A, B, y], x)

assert count_kron_ops(fn.maker.fgraph) == 0


@pytest.mark.parametrize(
"a_shape, b_shape",
[((5, 5), (5, 5)), ((50, 50), (50, 50)), ((100, 100), (100, 100))],
ids=["small", "medium", "large"],
)
@pytest.mark.parametrize("rewrite", [True, False], ids=["rewrite", "no_rewrite"])
def test_rewrite_solve_kron_to_solve_benchmark(a_shape, b_shape, rewrite, benchmark):
A, B = pt.tensor("A", shape=a_shape), pt.tensor("B", shape=b_shape)
C = pt.linalg.kron(A, B)

m, n = a_shape[-2], b_shape[-2]
has_batch = len(a_shape) == 3
y_shape = (a_shape[0], m * n) if has_batch else (m * n,)
y = pt.tensor("y", shape=y_shape)
x = pt.linalg.solve(C, y, b_ndim=1)

rng = np.random.default_rng(sum(map(ord, "Go away Kron!")))
a_val = rng.normal(size=a_shape).astype(config.floatX)
b_val = rng.normal(size=b_shape).astype(config.floatX)
y_val = rng.normal(size=y_shape).astype(config.floatX)

mode = (
get_default_mode()
if rewrite
else get_default_mode().excluding("rewrite_solve_kron_to_solve")
)

fn = pytensor.function([A, B, y], x, mode=mode)
benchmark(fn, a_val, b_val, y_val)


def test_cholesky_eye_rewrite():
x = pt.eye(10)
L = pt.linalg.cholesky(x)
Expand Down