Implement linear algebra functions in PyTorch #922

Open · wants to merge 6 commits into base: main
Changes from all commits
1 change: 1 addition & 0 deletions pytensor/link/pytorch/dispatch/__init__.py
@@ -11,4 +11,5 @@
import pytensor.link.pytorch.dispatch.shape
import pytensor.link.pytorch.dispatch.sort
import pytensor.link.pytorch.dispatch.subtensor
import pytensor.link.pytorch.dispatch.slinalg
# isort: on
88 changes: 88 additions & 0 deletions pytensor/link/pytorch/dispatch/slinalg.py
@@ -0,0 +1,88 @@
import torch

from pytensor.link.pytorch.dispatch.basic import pytorch_funcify
from pytensor.tensor.slinalg import (
    BlockDiagonal,
    Cholesky,
    Eigvalsh,
    Solve,
    SolveTriangular,
)


@pytorch_funcify.register(Eigvalsh)
def pytorch_funcify_Eigvalsh(op, **kwargs):
    if op.lower:
        UPLO = "L"
    else:
        UPLO = "U"

    def eigvalsh(a, b):
        if b is not None:
            raise NotImplementedError(
                "torch.linalg.eigvalsh does not support generalized eigenvalue problems (b != None)"
            )
        return torch.linalg.eigvalsh(a, UPLO=UPLO)

    return eigvalsh


@pytorch_funcify.register(Cholesky)
def pytorch_funcify_Cholesky(op, **kwargs):
    upper = not op.lower

    def cholesky(a):
        return torch.linalg.cholesky(a, upper=upper)

    return cholesky


@pytorch_funcify.register(Solve)
def pytorch_funcify_Solve(op, **kwargs):
    def solve(a, b):
        return torch.linalg.solve(a, b)

    return solve


@pytorch_funcify.register(SolveTriangular)
def pytorch_funcify_SolveTriangular(op, **kwargs):
    if op.check_finite:
        raise NotImplementedError(
            "Option check_finite is not implemented in torch.linalg.solve_triangular"
        )

    upper = not op.lower
    unit_diagonal = op.unit_diagonal
    trans = op.trans

    def solve_triangular(A, b):
        # torch.linalg.solve_triangular has no `trans` argument, so apply the
        # requested (conjugate) transpose to A before solving.
        if trans in [1, "T"]:
            A_p = A.T
        elif trans in [2, "C"]:
            A_p = A.conj().T
        else:
            A_p = A

        # torch.linalg.solve_triangular expects b to be at least 2-D; promote
        # vectors to column matrices and flatten the result again below.
        b_p = b
        if b.ndim == 1:
            b_p = b[:, None]

        res = torch.linalg.solve_triangular(
            A_p, b_p, upper=upper, unitriangular=unit_diagonal
        )

        if b.ndim == 1 and res.shape[1] == 1:
            return res.flatten()

        return res

    return solve_triangular


@pytorch_funcify.register(BlockDiagonal)
def pytorch_funcify_BlockDiagonalMatrix(op, **kwargs):
    def block_diag(*inputs):
        return torch.block_diag(*inputs)

    return block_diag
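
As a side note for reviewers, here is a minimal sketch (not part of the diff) of how these registrations are exercised: `pytorch_funcify` turns a PyTensor Op instance into a callable that operates on torch tensors. Constructing the Op directly as `Cholesky(lower=True)` and the example values are assumptions made for illustration only.

import torch

from pytensor.link.pytorch.dispatch.basic import pytorch_funcify
from pytensor.tensor.slinalg import Cholesky

# Dispatch on the Op type to obtain the torch-backed callable registered above
# (Cholesky(lower=True) is assumed to be a valid way to build the Op here).
cholesky_fn = pytorch_funcify(Cholesky(lower=True))

A = 4.0 * torch.eye(3)
L = cholesky_fn(A)                  # lower-triangular factor, here 2 * I
print(torch.allclose(L @ L.T, A))   # expected: True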
129 changes: 129 additions & 0 deletions tests/link/pytorch/test_slinalg.py
@@ -0,0 +1,129 @@
import numpy as np
import pytest

from pytensor.configdefaults import config
from pytensor.graph.fg import FunctionGraph
from pytensor.tensor import slinalg as pt_slinalg
from pytensor.tensor.type import matrix, vector
from tests.link.pytorch.test_basic import compare_pytorch_and_py


@pytest.mark.parametrize("lower", [False, True])
def test_pytorch_eigvalsh(lower):
    A = matrix("A")
    B = matrix("B")

    out = pt_slinalg.eigvalsh(A, B, lower=lower)
    out_fg = FunctionGraph([A, B], [out])

    with pytest.raises(NotImplementedError):
        compare_pytorch_and_py(
            out_fg,
            [
                np.array(
                    [[6, 3, 1, 5], [3, 0, 5, 1], [1, 5, 6, 2], [5, 1, 2, 2]]
                ).astype(config.floatX),
                np.array(
                    [[10, 0, 1, 3], [0, 12, 7, 8], [1, 7, 14, 2], [3, 8, 2, 16]]
                ).astype(config.floatX),
            ],
        )
    compare_pytorch_and_py(
        out_fg,
        [
            np.array([[6, 3, 1, 5], [3, 0, 5, 1], [1, 5, 6, 2], [5, 1, 2, 2]]).astype(
                config.floatX
            ),
            None,
        ],
    )


def test_pytorch_cholesky():
    rng = np.random.default_rng(28494)

    x = matrix("x")

    out = pt_slinalg.cholesky(x)
    out_fg = FunctionGraph([x], [out])
    compare_pytorch_and_py(
        out_fg,
        [
            (np.eye(10) + rng.standard_normal(size=(10, 10)) * 0.01).astype(
                config.floatX
            )
        ],
    )

    out = pt_slinalg.cholesky(x, lower=False)
    out_fg = FunctionGraph([x], [out])
    compare_pytorch_and_py(
        out_fg,
        [
            (np.eye(10) + rng.standard_normal(size=(10, 10)) * 0.01).astype(
                config.floatX
            )
        ],
    )


def test_pytorch_solve():
    x = matrix("x")
    b = vector("b")

    out = pt_slinalg.solve(x, b)
    out_fg = FunctionGraph([x, b], [out])
    compare_pytorch_and_py(
        out_fg,
        [
            np.eye(10).astype(config.floatX),
            np.arange(10).astype(config.floatX),
        ],
    )


@pytest.mark.parametrize(
    "check_finite",
    (False, pytest.param(True, marks=pytest.mark.xfail(raises=NotImplementedError))),
)
@pytest.mark.parametrize("lower", [False, True])
@pytest.mark.parametrize("trans", [0, 1, 2, "S"])
def test_pytorch_SolveTriangular(trans, lower, check_finite):
    x = matrix("x")
    b = vector("b")

    out = pt_slinalg.solve_triangular(
        x,
        b,
        trans=trans,
        lower=lower,
        check_finite=check_finite,
    )
    out_fg = FunctionGraph([x, b], [out])
    compare_pytorch_and_py(
        out_fg,
        [
            np.eye(10).astype(config.floatX),
            np.arange(10).astype(config.floatX),
        ],
    )


def test_pytorch_block_diag():
    A = matrix("A")
    B = matrix("B")
    C = matrix("C")
    D = matrix("D")

    out = pt_slinalg.block_diag(A, B, C, D)
    out_fg = FunctionGraph([A, B, C, D], [out])

    compare_pytorch_and_py(
        out_fg,
        [
            np.random.normal(size=(5, 5)).astype(config.floatX),
            np.random.normal(size=(3, 3)).astype(config.floatX),
            np.random.normal(size=(2, 2)).astype(config.floatX),
            np.random.normal(size=(4, 4)).astype(config.floatX),
        ],
    )
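
For context, once this lands the new ops should also be reachable through the regular compilation path. The sketch below assumes the PyTorch backend is selected via mode="PYTORCH"; that mode name is an assumption for illustration and is not established by this diff.

import numpy as np

import pytensor
import pytensor.tensor as pt
from pytensor.tensor import slinalg as pt_slinalg

x = pt.matrix("x")
b = pt.vector("b")

# Compile with the PyTorch backend (mode name "PYTORCH" assumed here).
fn = pytensor.function([x, b], pt_slinalg.solve(x, b), mode="PYTORCH")

out = fn(
    np.eye(3, dtype=pytensor.config.floatX),
    np.arange(3, dtype=pytensor.config.floatX),
)
print(out)  # expected: [0. 1. 2.]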