From e2e32f17747a6fce559a754f48dc5e64dd16c3ee Mon Sep 17 00:00:00 2001 From: Corwin Joy Date: Tue, 30 May 2023 15:17:17 -0700 Subject: [PATCH 01/23] Rebase vs. main --- linear_operator/operators/__init__.py | 2 + .../operators/block_tensor_linear_operator.py | 84 +++++++++++++++++++ .../test_block_tensor_linear_operator.py | 74 ++++++++++++++++ 3 files changed, 160 insertions(+) create mode 100644 linear_operator/operators/block_tensor_linear_operator.py create mode 100644 test/operators/test_block_tensor_linear_operator.py diff --git a/linear_operator/operators/__init__.py b/linear_operator/operators/__init__.py index 79e92108..96ea8a61 100644 --- a/linear_operator/operators/__init__.py +++ b/linear_operator/operators/__init__.py @@ -6,6 +6,7 @@ from .block_diag_linear_operator import BlockDiagLinearOperator from .block_interleaved_linear_operator import BlockInterleavedLinearOperator from .block_linear_operator import BlockLinearOperator +from .block_tensor_linear_operator import BlockTensorLinearOperator from .cat_linear_operator import cat, CatLinearOperator from .chol_linear_operator import CholLinearOperator from .constant_mul_linear_operator import ConstantMulLinearOperator @@ -44,6 +45,7 @@ "BlockLinearOperator", "BlockDiagLinearOperator", "BlockInterleavedLinearOperator", + "BlockTensorLinearOperator", "CatLinearOperator", "CholLinearOperator", "ConstantDiagLinearOperator", diff --git a/linear_operator/operators/block_tensor_linear_operator.py b/linear_operator/operators/block_tensor_linear_operator.py new file mode 100644 index 00000000..04b7eba2 --- /dev/null +++ b/linear_operator/operators/block_tensor_linear_operator.py @@ -0,0 +1,84 @@ +from typing import List, Union + +import torch +from torch import Tensor + +from ._linear_operator import LinearOperator +from .dense_linear_operator import to_linear_operator + + +class BlockTensorLinearOperator(LinearOperator): + def __init__(self, linear_operators: List[List[LinearOperator]]) -> None: + assert len(linear_operators) > 0, "must have nested list" + assert len(linear_operators[0]) == len(linear_operators), "must be square over block dimensions" + + super().__init__(linear_operators) + + self.linear_operators = linear_operators + self.num_tasks = len(self.linear_operators) + + def _matmul( + self: Float[LinearOperator, "*batch M N"], + rhs: Union[Float[torch.Tensor, "*batch2 N C"], Float[torch.Tensor, "*batch2 N"]], + ) -> Union[Float[torch.Tensor, "... M C"], Float[torch.Tensor, "... 
M"]]: + + T = self.num_tasks + output = [] + for i in range(T): + tmp = [] + for j in range(T): + tmp.append([]) + output.append(tmp) + + if isinstance(other, self.__class__): + # TO DO: Check size is the same + for i in range(T): + for j in range(T): + out_ij = to_linear_operator( + torch.zeros(self.linear_operators[0][0].shape[0], other.linear_operators[0][0].shape[1]) + ) + for k in range(T): + out_ij += self.linear_operators[i][k] @ other.linear_operators[k][j] + output[i][j] = out_ij + elif isinstance(other, Tensor): + # Check both matrix dims divisible by T, + # reshape to (T, T, ), call .from_tensor + pass + + elif isinstance(other, LinearOperator): + pass + + else: + raise Exception("") + + return self.__class__(output) + + def to_dense(self: Float[LinearOperator, "*batch M N"]) -> Float[Tensor, "*batch M N"]: + out = [] + for i in range(self.num_tasks): + rows = [] + for j in range(self.num_tasks): + rows.append(self.linear_operators[i][j].to_dense()) + out.append(torch.concat(rows, axis=1)) + return torch.concat(out, axis=0) + + def _size(self) -> torch.Size: + sz = self.linear_operators[0][0].size() + return torch.Size([self.num_tasks * sz[0], self.num_tasks * sz[1]]) + + def _diag(self): + out = [] + for i in range(self.num_tasks): + diagonal = self.linear_operators[i][i].diagonal() + out.append(diagonal) + return torch.concat(out, axis=1) + + def _transpose_nonbatch(self: Float[LinearOperator, "*batch M N"]) -> Float[LinearOperator, "*batch N M"]: + return self # Diagonal matrices are symmetric + + @classmethod + def from_tensor(cls, tensor: Tensor, num_tasks: int): + linear_ops = [ + [to_linear_operator(t[0]) for t in list(torch.tensor_split(tensor[i], num_tasks))] for i in range(num_tasks) + ] + return cls(linear_ops) diff --git a/test/operators/test_block_tensor_linear_operator.py b/test/operators/test_block_tensor_linear_operator.py new file mode 100644 index 00000000..8d9e9ebc --- /dev/null +++ b/test/operators/test_block_tensor_linear_operator.py @@ -0,0 +1,74 @@ +#!/usr/bin/env python3 + +import unittest + +import torch + +from linear_operator.operators import BlockTensorLinearOperator +from linear_operator.test.base_test_case import BaseTestCase + +# from linear_operator.test.linear_operator_test_case import LinearOperatorTestCase + + +class TestBlockBlockSimple(BaseTestCase, unittest.TestCase): + def test_multiply(self): + T = 2 + N = 4 + M = 3 + K = 5 + + A = torch.randn(T, T, N, M) + B = torch.randn(T, T, M, K) + + A_blo = BlockTensorLinearOperator.from_tensor(A, T) + B_blo = BlockTensorLinearOperator.from_tensor(B, T) + res_AB = A_blo._matmul(B_blo) + res_dense_AB = res_AB.to_dense() + + A_dense = A.permute(0, 2, 1, 3).reshape(T * N, T * M) + B_dense = B.permute(0, 2, 1, 3).reshape(T * M, T * K) + expected = A_dense @ B_dense + self.assertAllClose(res_dense_AB, expected) + self.assertAllClose(A_dense, A_blo.to_dense()) + self.assertAllClose(B_dense, B_blo.to_dense()) + + # Try to convert dense to block + Ne = A_dense.size(0) // T + Me = A_dense.size(1) // T + A_blocks_est = A_dense.reshape(T, Ne, T, Me) + A_blocks_est = A_blocks_est.permute(0, 2, 1, 3) + self.assertAllClose(A, A_blocks_est) + + # Check Tensor multiplication + # res_tensor_AB = A_blo._matmul(B_dense) + # res_tensor_dense_AB = res_tensor_AB.to_dense() + # self.assertAllClose(res_dense_AB, res_tensor_dense_AB) + + +rem = """ + +class TestBlockBlockLinearOperator(LinearOperatorTestCase, unittest.TestCase): + seed = 0 + should_test_sample = False + T = 2 + N = M = 4 # Try a square for this set of tests 
+ # N = 4 + # M = 3 + + A_dense = torch.eye(T * N) + A_blocks = A_dense.reshape(T, N, T, M).permute(0, 2, 1, 3) + + # A = torch.randn(T, T, N, M) # Need to make something +ve definite + + def create_linear_op(self): + A_blo = BlockBLockLinearOperator.from_tensor(self.A_blocks) + return A_blo + + def evaluate_linear_op(self, linear_op): + D = linear_op.to_dense() + return D +""" + + +if __name__ == "__main__": + unittest.main() From 15effdfcb978c70d45d53270948c13fdd1dff520 Mon Sep 17 00:00:00 2001 From: Corwin Joy Date: Tue, 30 May 2023 15:24:02 -0700 Subject: [PATCH 02/23] Fix block tensor type signatures --- .../operators/block_tensor_linear_operator.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/linear_operator/operators/block_tensor_linear_operator.py b/linear_operator/operators/block_tensor_linear_operator.py index 04b7eba2..aad69e88 100644 --- a/linear_operator/operators/block_tensor_linear_operator.py +++ b/linear_operator/operators/block_tensor_linear_operator.py @@ -1,6 +1,7 @@ from typing import List, Union import torch +from jaxtyping import Float from torch import Tensor from ._linear_operator import LinearOperator @@ -30,22 +31,22 @@ def _matmul( tmp.append([]) output.append(tmp) - if isinstance(other, self.__class__): + if isinstance(rhs, self.__class__): # TO DO: Check size is the same for i in range(T): for j in range(T): out_ij = to_linear_operator( - torch.zeros(self.linear_operators[0][0].shape[0], other.linear_operators[0][0].shape[1]) + torch.zeros(self.linear_operators[0][0].shape[0], rhs.linear_operators[0][0].shape[1]) ) for k in range(T): - out_ij += self.linear_operators[i][k] @ other.linear_operators[k][j] + out_ij += self.linear_operators[i][k] @ rhs.linear_operators[k][j] output[i][j] = out_ij - elif isinstance(other, Tensor): + elif isinstance(rhs, Tensor): # Check both matrix dims divisible by T, # reshape to (T, T, ), call .from_tensor pass - elif isinstance(other, LinearOperator): + elif isinstance(rhs, LinearOperator): pass else: From 4055a7fd2f0b01be86692eb2895c8f2766d6b0eb Mon Sep 17 00:00:00 2001 From: Corwin Joy Date: Tue, 30 May 2023 17:21:20 -0700 Subject: [PATCH 03/23] Get simple test running for BlockTensor --- .../operators/block_tensor_linear_operator.py | 52 +++++++++++-------- .../test_block_tensor_linear_operator.py | 12 ++--- 2 files changed, 36 insertions(+), 28 deletions(-) diff --git a/linear_operator/operators/block_tensor_linear_operator.py b/linear_operator/operators/block_tensor_linear_operator.py index aad69e88..80dc2d92 100644 --- a/linear_operator/operators/block_tensor_linear_operator.py +++ b/linear_operator/operators/block_tensor_linear_operator.py @@ -17,6 +17,8 @@ def __init__(self, linear_operators: List[List[LinearOperator]]) -> None: self.linear_operators = linear_operators self.num_tasks = len(self.linear_operators) + self.block_rows = linear_operators[0][0].shape[0] + self.block_cols = linear_operators[0][0].shape[1] def _matmul( self: Float[LinearOperator, "*batch M N"], @@ -24,35 +26,41 @@ def _matmul( ) -> Union[Float[torch.Tensor, "... M C"], Float[torch.Tensor, "... M"]]: T = self.num_tasks - output = [] - for i in range(T): - tmp = [] - for j in range(T): - tmp.append([]) - output.append(tmp) - if isinstance(rhs, self.__class__): - # TO DO: Check size is the same + # A is block [N * T1, M * T2] and B is block [O * S1, P * S2]. If A and B have conformal block counts + # ie T2==S1 as well as M==O then use the blockwise algorithm. 
Else use to_dense() + if isinstance(rhs, self.__class__) and self.num_tasks == rhs.num_tasks and self.block_cols == rhs.block_rows: + output = [] for i in range(T): + tmp = [] for j in range(T): - out_ij = to_linear_operator( - torch.zeros(self.linear_operators[0][0].shape[0], rhs.linear_operators[0][0].shape[1]) - ) - for k in range(T): + tmp.append([]) + output.append(tmp) + for i in range(T): + for j in range(T): + out_ij = self.linear_operators[i][0] @ rhs.linear_operators[0][j] + for k in range(1, T): out_ij += self.linear_operators[i][k] @ rhs.linear_operators[k][j] output[i][j] = out_ij + return self.__class__(output) elif isinstance(rhs, Tensor): # Check both matrix dims divisible by T, - # reshape to (T, T, ), call .from_tensor - pass - - elif isinstance(rhs, LinearOperator): - pass - - else: - raise Exception("") - - return self.__class__(output) + # reshape to (T, T, ), call block multiplication + if rhs.size(0) % T == 0 and rhs.size(1) % T == 0: + # A is block [N * T, M * T] and B is a general tensor/operator of shape [O, P]. + # If O and P are both divisible by T, + # then interpret B as a [O//T * T, P//T * T] block matrix + O_T = rhs.size(0) // T + P_T = rhs.size(1) // T + rhs_blocks_raw = rhs.reshape(T, O_T, T, P_T) + rhs_blocks = rhs_blocks_raw.permute(0, 2, 1, 3) + rhs_op = BlockTensorLinearOperator.from_tensor(rhs_blocks, T) + return self._matmul(rhs_op) + + A = self.to_dense() + B = rhs.to_dense() + res = A @ B + return res def to_dense(self: Float[LinearOperator, "*batch M N"]) -> Float[Tensor, "*batch M N"]: out = [] diff --git a/test/operators/test_block_tensor_linear_operator.py b/test/operators/test_block_tensor_linear_operator.py index 8d9e9ebc..65ec8492 100644 --- a/test/operators/test_block_tensor_linear_operator.py +++ b/test/operators/test_block_tensor_linear_operator.py @@ -10,7 +10,7 @@ # from linear_operator.test.linear_operator_test_case import LinearOperatorTestCase -class TestBlockBlockSimple(BaseTestCase, unittest.TestCase): +class TestBlockTensorSimple(BaseTestCase, unittest.TestCase): def test_multiply(self): T = 2 N = 4 @@ -32,7 +32,7 @@ def test_multiply(self): self.assertAllClose(A_dense, A_blo.to_dense()) self.assertAllClose(B_dense, B_blo.to_dense()) - # Try to convert dense to block + # Convert dense format back to blocks and compare Ne = A_dense.size(0) // T Me = A_dense.size(1) // T A_blocks_est = A_dense.reshape(T, Ne, T, Me) @@ -40,14 +40,14 @@ def test_multiply(self): self.assertAllClose(A, A_blocks_est) # Check Tensor multiplication - # res_tensor_AB = A_blo._matmul(B_dense) - # res_tensor_dense_AB = res_tensor_AB.to_dense() - # self.assertAllClose(res_dense_AB, res_tensor_dense_AB) + res_tensor_AB = A_blo._matmul(B_dense) + res_tensor_dense_AB = res_tensor_AB.to_dense() + self.assertAllClose(res_dense_AB, res_tensor_dense_AB) rem = """ -class TestBlockBlockLinearOperator(LinearOperatorTestCase, unittest.TestCase): +class TestBlockTensorLinearOperator(LinearOperatorTestCase, unittest.TestCase): seed = 0 should_test_sample = False T = 2 From 2cf2db0e0148521409cef4ed4c6c0e977b6b168a Mon Sep 17 00:00:00 2001 From: Corwin Joy Date: Tue, 30 May 2023 18:19:03 -0700 Subject: [PATCH 04/23] Add simple property implementations --- .../operators/block_tensor_linear_operator.py | 42 +++++++++++++++---- .../test_block_tensor_linear_operator.py | 3 +- 2 files changed, 35 insertions(+), 10 deletions(-) diff --git a/linear_operator/operators/block_tensor_linear_operator.py b/linear_operator/operators/block_tensor_linear_operator.py index 80dc2d92..676f0d9a 
100644 --- a/linear_operator/operators/block_tensor_linear_operator.py +++ b/linear_operator/operators/block_tensor_linear_operator.py @@ -1,4 +1,4 @@ -from typing import List, Union +from typing import List, Optional, Tuple, Union import torch from jaxtyping import Float @@ -10,7 +10,8 @@ class BlockTensorLinearOperator(LinearOperator): def __init__(self, linear_operators: List[List[LinearOperator]]) -> None: - assert len(linear_operators) > 0, "must have nested list" + assert isinstance(linear_operators, list) + assert len(linear_operators) > 0, "must have non-empty list" assert len(linear_operators[0]) == len(linear_operators), "must be square over block dimensions" super().__init__(linear_operators) @@ -20,6 +21,17 @@ def __init__(self, linear_operators: List[List[LinearOperator]]) -> None: self.block_rows = linear_operators[0][0].shape[0] self.block_cols = linear_operators[0][0].shape[1] + @staticmethod + def square_ops(T): + """Return an empty (square) list of operators of shape TxT""" + ops = [] + for i in range(T): + tmp = [] + for j in range(T): + tmp.append([]) + ops.append(tmp) + return ops + def _matmul( self: Float[LinearOperator, "*batch M N"], rhs: Union[Float[torch.Tensor, "*batch2 N C"], Float[torch.Tensor, "*batch2 N"]], @@ -30,12 +42,7 @@ def _matmul( # A is block [N * T1, M * T2] and B is block [O * S1, P * S2]. If A and B have conformal block counts # ie T2==S1 as well as M==O then use the blockwise algorithm. Else use to_dense() if isinstance(rhs, self.__class__) and self.num_tasks == rhs.num_tasks and self.block_cols == rhs.block_rows: - output = [] - for i in range(T): - tmp = [] - for j in range(T): - tmp.append([]) - output.append(tmp) + output = BlockTensorLinearOperator.square_ops(T) for i in range(T): for j in range(T): out_ij = self.linear_operators[i][0] @ rhs.linear_operators[0][j] @@ -57,6 +64,7 @@ def _matmul( rhs_op = BlockTensorLinearOperator.from_tensor(rhs_blocks, T) return self._matmul(rhs_op) + # Failover implementation. 
Convert to dense and multiply matricies A = self.to_dense() B = rhs.to_dense() res = A @ B @@ -75,6 +83,24 @@ def _size(self) -> torch.Size: sz = self.linear_operators[0][0].size() return torch.Size([self.num_tasks * sz[0], self.num_tasks * sz[1]]) + @property + def dtype(self) -> Optional[torch.dtype]: + return self.linear_operators[0][0].dtype + + @property + def device(self) -> Optional[torch.device]: + return self.linear_operators[0][0].device + + def representation(self) -> Tuple[torch.Tensor, ...]: + """ + Returns the Tensors that are used to define the LinearOperator + """ + representation = [] + for op_row in self.linear_operators: + for op in op_row: + representation += tuple(op.representation()) + return tuple(representation) + def _diag(self): out = [] for i in range(self.num_tasks): diff --git a/test/operators/test_block_tensor_linear_operator.py b/test/operators/test_block_tensor_linear_operator.py index 65ec8492..94bddbf4 100644 --- a/test/operators/test_block_tensor_linear_operator.py +++ b/test/operators/test_block_tensor_linear_operator.py @@ -46,7 +46,6 @@ def test_multiply(self): rem = """ - class TestBlockTensorLinearOperator(LinearOperatorTestCase, unittest.TestCase): seed = 0 should_test_sample = False @@ -61,7 +60,7 @@ class TestBlockTensorLinearOperator(LinearOperatorTestCase, unittest.TestCase): # A = torch.randn(T, T, N, M) # Need to make something +ve definite def create_linear_op(self): - A_blo = BlockBLockLinearOperator.from_tensor(self.A_blocks) + A_blo = BlockTensorLinearOperator.from_tensor(self.A_blocks, self.T) return A_blo def evaluate_linear_op(self, linear_op): From 61d9b8a084e962a2a2ba3f8b4e1644ac24dd4a5c Mon Sep 17 00:00:00 2001 From: Corwin Joy Date: Wed, 31 May 2023 13:42:41 -0700 Subject: [PATCH 05/23] Upgrade linear operator to add block / sparse test --- .../operators/block_tensor_linear_operator.py | 11 ++++- .../test_block_tensor_linear_operator.py | 48 +++++++++++++++++-- 2 files changed, 52 insertions(+), 7 deletions(-) diff --git a/linear_operator/operators/block_tensor_linear_operator.py b/linear_operator/operators/block_tensor_linear_operator.py index 676f0d9a..c27e3efa 100644 --- a/linear_operator/operators/block_tensor_linear_operator.py +++ b/linear_operator/operators/block_tensor_linear_operator.py @@ -5,7 +5,8 @@ from torch import Tensor from ._linear_operator import LinearOperator -from .dense_linear_operator import to_linear_operator +from .dense_linear_operator import DenseLinearOperator +from .zero_linear_operator import ZeroLinearOperator class BlockTensorLinearOperator(LinearOperator): @@ -113,7 +114,13 @@ def _transpose_nonbatch(self: Float[LinearOperator, "*batch M N"]) -> Float[Line @classmethod def from_tensor(cls, tensor: Tensor, num_tasks: int): + def tensor_to_linear_op(t): + if torch.count_nonzero(t) > 0: + return DenseLinearOperator(t) + return ZeroLinearOperator(*t.size(), dtype=t.dtype, device=t.device) + linear_ops = [ - [to_linear_operator(t[0]) for t in list(torch.tensor_split(tensor[i], num_tasks))] for i in range(num_tasks) + [tensor_to_linear_op(t[0]) for t in list(torch.tensor_split(tensor[i], num_tasks))] + for i in range(num_tasks) ] return cls(linear_ops) diff --git a/test/operators/test_block_tensor_linear_operator.py b/test/operators/test_block_tensor_linear_operator.py index 94bddbf4..029a5f97 100644 --- a/test/operators/test_block_tensor_linear_operator.py +++ b/test/operators/test_block_tensor_linear_operator.py @@ -1,5 +1,5 @@ #!/usr/bin/env python3 - +import itertools import unittest import torch @@ 
-11,6 +11,13 @@ class TestBlockTensorSimple(BaseTestCase, unittest.TestCase): + def dense_to_4d(self, A_dense, T): + Ne = A_dense.size(0) // T + Me = A_dense.size(1) // T + A_blocks_est = A_dense.reshape(T, Ne, T, Me) + A_blocks_est = A_blocks_est.permute(0, 2, 1, 3) + return A_blocks_est + def test_multiply(self): T = 2 N = 4 @@ -33,10 +40,7 @@ def test_multiply(self): self.assertAllClose(B_dense, B_blo.to_dense()) # Convert dense format back to blocks and compare - Ne = A_dense.size(0) // T - Me = A_dense.size(1) // T - A_blocks_est = A_dense.reshape(T, Ne, T, Me) - A_blocks_est = A_blocks_est.permute(0, 2, 1, 3) + A_blocks_est = self.dense_to_4d(A_dense, T) self.assertAllClose(A, A_blocks_est) # Check Tensor multiplication @@ -44,6 +48,40 @@ def test_multiply(self): res_tensor_dense_AB = res_tensor_AB.to_dense() self.assertAllClose(res_dense_AB, res_tensor_dense_AB) + def test_sparse_multiply(self): + T, N, M = 2, 4, 3 + As = [torch.rand(N, M) for _ in range(T)] + Bs = [[torch.rand(M, M) for _ in range(T)] for _ in range(T)] + Cs = [torch.rand(N, N) for _ in range(T)] + # L = torch.rand(T, T) + + A_dense = torch.zeros((N * T, M * T)) # BlockDiag (non-square) + B_dense = torch.zeros((M * T, M * T)) # Dense + C_dense = torch.zeros((N * T, N * T)) # BlockDiag + # L_dense = torch.kron(L, torch.eye(N)) # Kroneker + + for t in range(T): + A_dense[N * t : N * (t + 1), M * t : M * (t + 1)] = As[t] + C_dense[N * t : N * (t + 1), N * t : N * (t + 1)] = Cs[t] + + for t1, t2 in itertools.product(range(T), range(T)): + B_dense[M * t1 : M * (t1 + 1), M * t2 : M * (t2 + 1)] = Bs[t1][t2] + + # Convert dense formats to blocks + A = self.dense_to_4d(A_dense, T) + B = self.dense_to_4d(B_dense, T) + + # A_blo will contain dense operators along the diagonal + Zero operators off diagonal + A_blo = BlockTensorLinearOperator.from_tensor(A, T) + B_blo = BlockTensorLinearOperator.from_tensor(B, T) + res_AB = A_blo._matmul(B_blo) + res_dense_AB = res_AB.to_dense() + + expected = A_dense @ B_dense + self.assertAllClose(res_dense_AB, expected) + self.assertAllClose(A_dense, A_blo.to_dense()) + self.assertAllClose(B_dense, B_blo.to_dense()) + rem = """ class TestBlockTensorLinearOperator(LinearOperatorTestCase, unittest.TestCase): From bbc12ea2a8de38d7cc32bacda1e7f57f77749209 Mon Sep 17 00:00:00 2001 From: Corwin Joy Date: Wed, 31 May 2023 16:11:40 -0700 Subject: [PATCH 06/23] Add and document core test cases --- .../operators/block_tensor_linear_operator.py | 25 +- .../test/linear_operator_core_test_case.py | 307 ++++++++++++++++++ .../test_block_tensor_linear_operator.py | 7 +- 3 files changed, 331 insertions(+), 8 deletions(-) create mode 100644 linear_operator/test/linear_operator_core_test_case.py diff --git a/linear_operator/operators/block_tensor_linear_operator.py b/linear_operator/operators/block_tensor_linear_operator.py index c27e3efa..ffd98234 100644 --- a/linear_operator/operators/block_tensor_linear_operator.py +++ b/linear_operator/operators/block_tensor_linear_operator.py @@ -4,7 +4,7 @@ from jaxtyping import Float from torch import Tensor -from ._linear_operator import LinearOperator +from ._linear_operator import IndexType, LinearOperator from .dense_linear_operator import DenseLinearOperator from .zero_linear_operator import ZeroLinearOperator @@ -51,7 +51,7 @@ def _matmul( out_ij += self.linear_operators[i][k] @ rhs.linear_operators[k][j] output[i][j] = out_ij return self.__class__(output) - elif isinstance(rhs, Tensor): + elif isinstance(rhs, Tensor) and rhs.ndim == 2: # Check both matrix dims 
divisible by T, # reshape to (T, T, ), call block multiplication if rhs.size(0) % T == 0 and rhs.size(1) % T == 0: @@ -71,6 +71,12 @@ def _matmul( res = A @ B return res + def matmul( + self: Float[LinearOperator, "*batch M N"], + other: Union[Float[Tensor, "*batch2 N P"], Float[Tensor, "*batch2 N"], Float[LinearOperator, "*batch2 N P"]], + ) -> Union[Float[Tensor, "... M P"], Float[Tensor, "... M"], Float[LinearOperator, "... M P"]]: + return self._matmul(other) + def to_dense(self: Float[LinearOperator, "*batch M N"]) -> Float[Tensor, "*batch M N"]: out = [] for i in range(self.num_tasks): @@ -110,7 +116,20 @@ def _diag(self): return torch.concat(out, axis=1) def _transpose_nonbatch(self: Float[LinearOperator, "*batch M N"]) -> Float[LinearOperator, "*batch N M"]: - return self # Diagonal matrices are symmetric + out = [] + for i in range(self.num_tasks): + rows = [] + for j in range(self.num_tasks): + rows.append(self.linear_operators[j][i].mT) + out.append(rows) + return BlockTensorLinearOperator(out) + + def _getitem(self, row_index: IndexType, col_index: IndexType, *batch_indices: IndexType) -> LinearOperator: + # Perform the __getitem__ + # TODO make this faster, see block_linear_operator + tsr = self.to_dense() + res = tsr[(*batch_indices, row_index, col_index)] + return DenseLinearOperator(res) @classmethod def from_tensor(cls, tensor: Tensor, num_tasks: int): diff --git a/linear_operator/test/linear_operator_core_test_case.py b/linear_operator/test/linear_operator_core_test_case.py new file mode 100644 index 00000000..014b04cd --- /dev/null +++ b/linear_operator/test/linear_operator_core_test_case.py @@ -0,0 +1,307 @@ +#!/usr/bin/env python3 + +from abc import abstractmethod + +import torch + +import linear_operator +from linear_operator.operators import DiagLinearOperator, to_dense +from .base_test_case import BaseTestCase + +rem = """ +In code, a LinearOperator is a class that + +specifies the tensor(s) needed to define the LinearOperator, +specifies a _matmul function (how the LinearOperator is applied to a vector), +specifies a _size function (how big is the LinearOperator if it is represented as a matrix, or batch of matrices), and +specifies a _transpose_nonbatch function (the adjoint of the LinearOperator). +(optionally) defines other functions (e.g. logdet, eigh, etc.) to accelerate computations for which efficient +sturcture-exploiting routines exist. 
+""" + + +class CoreLinearOperatorTestCase(BaseTestCase): + """Test the core operations for a LinearOperator""" + + tolerances = { + "matmul": {"rtol": 1e-3}, + "transpose": {"rtol": 1e-4, "atol": 1e-5}, + } + + @abstractmethod + def create_linear_op(self): + raise NotImplementedError() + + @abstractmethod + def evaluate_linear_op(self): + raise NotImplementedError() + + def _test_matmul(self, rhs): + linear_op = self.create_linear_op().detach().requires_grad_(True) + linear_op_copy = torch.clone(linear_op).detach().requires_grad_(True) + evaluated = self.evaluate_linear_op(linear_op_copy) + rhs_evaluated = to_dense(rhs) + + # Test operator + res = linear_op @ rhs + actual = evaluated.matmul(rhs_evaluated) + res_evaluated = to_dense(res) + self.assertAllClose(res_evaluated, actual) + + # Test __torch_function__ + res = torch.matmul(linear_op, rhs) + actual = evaluated.matmul(rhs) + self.assertAllClose(to_dense(res), actual) + + def test_transpose_nonbatch(self): + linear_op = self.create_linear_op() + evaluated = self.evaluate_linear_op(linear_op) + + res = linear_op._transpose_nonbatch() + actual = evaluated.mT + res_evaluated = to_dense(res) + self.assertAllClose(res_evaluated, actual, **self.tolerances["transpose"]) + + def _test_rmatmul(self, lhs): + # Note. transpose_nonbatch is tested implicitly here because + # the base linear operator class defines + # def rmatmul(other): + # return self.mT.matmul(other.mT).mT + linear_op = self.create_linear_op().detach().requires_grad_(True) + linear_op_copy = torch.clone(linear_op).detach().requires_grad_(True) + evaluated = self.evaluate_linear_op(linear_op_copy) + + # Test operator + res = lhs @ linear_op + res_evaluated = to_dense(res) + actual = lhs @ evaluated + self.assertAllClose(res_evaluated, actual) + + # Test __torch_function__ + res = torch.matmul(lhs, linear_op) + res_evaluated = to_dense(res) + actual = torch.matmul(lhs, evaluated) + self.assertAllClose(res_evaluated, actual) + + def test_add(self): + linear_op = self.create_linear_op() + evaluated = self.evaluate_linear_op(linear_op) + + rhs = torch.randn(linear_op.shape) + # Test operator functionality + a = (linear_op + rhs).to_dense() + b = evaluated + rhs + self.assertAllClose(a, b) + self.assertAllClose((linear_op + rhs).to_dense(), evaluated + rhs) + self.assertAllClose((rhs + linear_op).to_dense(), evaluated + rhs) + # Test __torch_function__ functionality + self.assertAllClose(torch.add(linear_op, rhs).to_dense(), evaluated + rhs) + self.assertAllClose(torch.add(rhs, linear_op).to_dense(), evaluated + rhs) + + rhs = torch.randn(linear_op.matrix_shape) + self.assertAllClose((linear_op + rhs).to_dense(), evaluated + rhs) + + self.assertAllClose((linear_op + linear_op).to_dense(), evaluated * 2) + + def test_matmul_vec(self): + linear_op = self.create_linear_op() + + # We skip this test if we're dealing with batch LinearOperators + # They shouldn't multiply by a vec + if linear_op.ndimension() > 2: + return + + rhs = torch.randn(linear_op.size(-1)) + return self._test_matmul(rhs) + + def test_constant_mul(self): + linear_op = self.create_linear_op() + evaluated = self.evaluate_linear_op(linear_op) + + # Test operator functionality + self.assertAllClose((linear_op * 5.0).to_dense(), evaluated * 5.0) + self.assertAllClose((linear_op * torch.tensor(5.0)).to_dense(), evaluated * 5.0) + self.assertAllClose((5.0 * linear_op).to_dense(), evaluated * 5.0) + self.assertAllClose((torch.tensor(5.0) * linear_op).to_dense(), evaluated * 5.0) + + # Test __torch_function__ functionality + 
self.assertAllClose(torch.mul(linear_op, torch.tensor(5.0)).to_dense(), evaluated * 5.0) + self.assertAllClose(torch.mul(torch.tensor(5.0), linear_op).to_dense(), evaluated * 5.0) + + def test_constant_mul_neg(self): + linear_op = self.create_linear_op() + evaluated = self.evaluate_linear_op(linear_op) + self.assertAllClose((linear_op * -5.0).to_dense(), evaluated * -5.0) + + def test_constant_div(self): + linear_op = self.create_linear_op() + evaluated = self.evaluate_linear_op(linear_op) + + # Test operator functionality + self.assertAllClose((linear_op / 5.0).to_dense(), evaluated / 5.0) + self.assertAllClose((linear_op / torch.tensor(5.0)).to_dense(), evaluated / 5.0) + + # Test __torch_function__ functionality + self.assertAllClose(torch.div(linear_op, torch.tensor(5.0)).to_dense(), evaluated / 5.0) + + def test_to_dense(self): + linear_op = self.create_linear_op() + evaluated = self.evaluate_linear_op(linear_op) + self.assertAllClose(linear_op.to_dense(), evaluated) + + def test_getitem(self): + linear_op = self.create_linear_op() + evaluated = self.evaluate_linear_op(linear_op) + + # Non-batch case + if linear_op.ndimension() == 2: + res = linear_op[1] + actual = evaluated[1] + self.assertAllClose(res, actual) + res = linear_op[0:2].to_dense() + actual = evaluated[0:2] + self.assertAllClose(res, actual) + res = linear_op[:, 0:2].to_dense() + actual = evaluated[:, 0:2] + self.assertAllClose(res, actual) + res = linear_op[0:2, :].to_dense() + actual = evaluated[0:2, :] + self.assertAllClose(res, actual) + res = linear_op[..., 0:2].to_dense() + actual = evaluated[..., 0:2] + self.assertAllClose(res, actual) + res = linear_op[0:2, ...].to_dense() + actual = evaluated[0:2, ...] + self.assertAllClose(res, actual) + res = linear_op[..., 0:2, 2] + actual = evaluated[..., 0:2, 2] + self.assertAllClose(res, actual) + res = linear_op[0:2, ..., 2] + actual = evaluated[0:2, ..., 2] + self.assertAllClose(res, actual) + + def test_getitem_tensor_index(self): + linear_op = self.create_linear_op() + evaluated = self.evaluate_linear_op(linear_op) + + # Non-batch case + if linear_op.ndimension() == 2: + index = (torch.tensor([0, 0, 1, 2]), torch.tensor([0, 1, 0, 2])) + res, actual = linear_op[index], evaluated[index] + self.assertAllClose(res, actual) + index = (torch.tensor([0, 0, 1, 2]), slice(None, None, None)) + res, actual = linear_operator.to_dense(linear_op[index]), evaluated[index] + self.assertAllClose(res, actual) + index = (slice(None, None, None), torch.tensor([0, 0, 1, 2])) + res, actual = linear_operator.to_dense(linear_op[index]), evaluated[index] + self.assertAllClose(res, actual) + index = (torch.tensor([0, 0, 1, 2]), Ellipsis) + res, actual = linear_operator.to_dense(linear_op[index]), evaluated[index] + self.assertAllClose(res, actual) + index = (Ellipsis, torch.tensor([0, 0, 1, 2])) + res, actual = linear_operator.to_dense(linear_op[index]), evaluated[index] + self.assertAllClose(res, actual) + index = (Ellipsis, torch.tensor([0, 0, 1, 2]), torch.tensor([0, 1, 0, 2])) + res, actual = linear_op[index], evaluated[index] + self.assertAllClose(res, actual) + + def test_getitem_broadcasted_tensor_index(self): + linear_op = self.create_linear_op() + evaluated = self.evaluate_linear_op(linear_op) + + # Non-batch case + if linear_op.ndimension() == 2: + index = ( + torch.tensor([0, 0, 1, 2]).unsqueeze(-1), + torch.tensor([0, 1, 0, 2]).unsqueeze(-2), + ) + res, actual = linear_op[index], evaluated[index] + self.assertAllClose(res, actual) + index = ( + Ellipsis, + torch.tensor([0, 0, 1, 
2]).unsqueeze(-2), + torch.tensor([0, 1, 0, 2]).unsqueeze(-1), + ) + res, actual = linear_op[index], evaluated[index] + self.assertAllClose(res, actual) + + def test_permute(self): + linear_op = self.create_linear_op() + if linear_op.dim() >= 4: + evaluated = self.evaluate_linear_op(linear_op) + dims = torch.randperm(linear_op.dim() - 2).tolist() + + # Call using __torch_function__ + res = torch.permute(linear_op, (*dims, -2, -1)).to_dense() + actual = torch.permute(evaluated, (*dims, -2, -1)) + self.assertAllClose(res, actual) + + # Call using method + res = linear_op.permute(*dims, -2, -1).to_dense() + actual = torch.permute(evaluated, (*dims, -2, -1)) + self.assertAllClose(res, actual) + + def test_rmatmul_vec(self): + linear_op = self.create_linear_op() + + # We skip this test if we're dealing with batch LinearOperators + # They shouldn't multiply by a vec + if linear_op.ndimension() > 2: + return + + lhs = torch.randn(linear_op.size(-2)) + return self._test_rmatmul(lhs) + + def test_matmul_matrix(self): + linear_op = self.create_linear_op() + rhs = torch.randn(*linear_op.batch_shape, linear_op.size(-1), 4) + return self._test_matmul(rhs) + + def test_rmatmul_matrix(self): + linear_op = self.create_linear_op() + lhs = torch.randn(*linear_op.batch_shape, 4, linear_op.size(-2)) + return self._test_rmatmul(lhs) + + def test_matmul_diag_matrix(self): + linear_op = self.create_linear_op() + diag = torch.rand(*linear_op.batch_shape, linear_op.size(-1)) + rhs = DiagLinearOperator(diag) + return self._test_matmul(rhs) + + def test_rsub(self): + linear_op = self.create_linear_op() + evaluated = self.evaluate_linear_op(linear_op) + + rhs = torch.randn(linear_op.shape) + # Test operator functionality + self.assertAllClose((rhs - linear_op).to_dense(), rhs - evaluated) + # Test __torch_function__ functionality + self.assertAllClose(torch.sub(rhs, linear_op).to_dense(), rhs - evaluated) + + def test_sub(self): + linear_op = self.create_linear_op() + evaluated = self.evaluate_linear_op(linear_op) + + rhs = torch.randn(linear_op.shape) + # Test operator functionality + self.assertAllClose((linear_op - rhs).to_dense(), evaluated - rhs) + # Test __torch_function__ functionality + self.assertAllClose(torch.sub(linear_op, rhs).to_dense(), evaluated - rhs) + + def test_sum(self): + linear_op = self.create_linear_op() + evaluated = self.evaluate_linear_op(linear_op) + + self.assertAllClose(torch.sum(linear_op, -1), torch.sum(evaluated, -1)) + self.assertAllClose(torch.sum(linear_op, -2), torch.sum(evaluated, -2)) + if linear_op.ndimension() > 2: + self.assertAllClose(torch.sum(linear_op, -3).to_dense(), torch.sum(evaluated, -3)) + if linear_op.ndimension() > 3: + self.assertAllClose(torch.sum(linear_op, -4).to_dense(), torch.sum(evaluated, -4)) + + def test_add_jitter(self): + linear_op = self.create_linear_op() + evaluated = self.evaluate_linear_op(linear_op) + res = linear_operator.add_jitter(linear_op, 0.4).to_dense() + actual = evaluated + torch.eye(evaluated.size(-1)).mul_(0.4) + self.assertAllClose(res, actual) diff --git a/test/operators/test_block_tensor_linear_operator.py b/test/operators/test_block_tensor_linear_operator.py index 029a5f97..aa01a17d 100644 --- a/test/operators/test_block_tensor_linear_operator.py +++ b/test/operators/test_block_tensor_linear_operator.py @@ -6,8 +6,7 @@ from linear_operator.operators import BlockTensorLinearOperator from linear_operator.test.base_test_case import BaseTestCase - -# from linear_operator.test.linear_operator_test_case import LinearOperatorTestCase 
+from linear_operator.test.linear_operator_core_test_case import CoreLinearOperatorTestCase class TestBlockTensorSimple(BaseTestCase, unittest.TestCase): @@ -83,8 +82,7 @@ def test_sparse_multiply(self): self.assertAllClose(B_dense, B_blo.to_dense()) -rem = """ -class TestBlockTensorLinearOperator(LinearOperatorTestCase, unittest.TestCase): +class TestLinearOperatorBlockTensorLinearOperator(CoreLinearOperatorTestCase, unittest.TestCase): seed = 0 should_test_sample = False T = 2 @@ -104,7 +102,6 @@ def create_linear_op(self): def evaluate_linear_op(self, linear_op): D = linear_op.to_dense() return D -""" if __name__ == "__main__": From bfc843a9bb9527a443db3e05add746dd8f51e43f Mon Sep 17 00:00:00 2001 From: Corwin Joy Date: Wed, 31 May 2023 16:12:56 -0700 Subject: [PATCH 07/23] Cleanup dead comments --- test/operators/test_block_tensor_linear_operator.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/test/operators/test_block_tensor_linear_operator.py b/test/operators/test_block_tensor_linear_operator.py index aa01a17d..76567a72 100644 --- a/test/operators/test_block_tensor_linear_operator.py +++ b/test/operators/test_block_tensor_linear_operator.py @@ -87,14 +87,10 @@ class TestLinearOperatorBlockTensorLinearOperator(CoreLinearOperatorTestCase, un should_test_sample = False T = 2 N = M = 4 # Try a square for this set of tests - # N = 4 - # M = 3 A_dense = torch.eye(T * N) A_blocks = A_dense.reshape(T, N, T, M).permute(0, 2, 1, 3) - # A = torch.randn(T, T, N, M) # Need to make something +ve definite - def create_linear_op(self): A_blo = BlockTensorLinearOperator.from_tensor(self.A_blocks, self.T) return A_blo From 516b36915416cb617e005c407ca729fc523d00e2 Mon Sep 17 00:00:00 2001 From: Corwin Joy Date: Thu, 1 Jun 2023 15:41:49 -0700 Subject: [PATCH 08/23] Update linear_operator/operators/block_tensor_linear_operator.py Co-authored-by: Danny Friar --- linear_operator/operators/block_tensor_linear_operator.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/linear_operator/operators/block_tensor_linear_operator.py b/linear_operator/operators/block_tensor_linear_operator.py index ffd98234..d81af037 100644 --- a/linear_operator/operators/block_tensor_linear_operator.py +++ b/linear_operator/operators/block_tensor_linear_operator.py @@ -11,7 +11,7 @@ class BlockTensorLinearOperator(LinearOperator): def __init__(self, linear_operators: List[List[LinearOperator]]) -> None: - assert isinstance(linear_operators, list) + assert isinstance(linear_operators, list), f"{self.__class__.__name__} expects a nested list of LinearOperators` assert len(linear_operators) > 0, "must have non-empty list" assert len(linear_operators[0]) == len(linear_operators), "must be square over block dimensions" From 9d12d7c58ba9cacc53971e9d3b33dc2a2b639ac4 Mon Sep 17 00:00:00 2001 From: Corwin Joy Date: Thu, 1 Jun 2023 15:45:18 -0700 Subject: [PATCH 09/23] Update linear_operator/operators/block_tensor_linear_operator.py Co-authored-by: Danny Friar --- linear_operator/operators/block_tensor_linear_operator.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/linear_operator/operators/block_tensor_linear_operator.py b/linear_operator/operators/block_tensor_linear_operator.py index d81af037..40c66b4b 100644 --- a/linear_operator/operators/block_tensor_linear_operator.py +++ b/linear_operator/operators/block_tensor_linear_operator.py @@ -132,7 +132,7 @@ def _getitem(self, row_index: IndexType, col_index: IndexType, *batch_indices: I return DenseLinearOperator(res) @classmethod - def 
from_tensor(cls, tensor: Tensor, num_tasks: int): + def from_tensor(cls, tensor: Tensor, num_tasks: int) -> "BlockTensorLinearOperator": def tensor_to_linear_op(t): if torch.count_nonzero(t) > 0: return DenseLinearOperator(t) From 56809b23b5a72ee773f6fc387ef12de9f56da8bb Mon Sep 17 00:00:00 2001 From: Corwin Joy Date: Thu, 1 Jun 2023 16:10:56 -0700 Subject: [PATCH 10/23] Improve construction tests and types --- .../operators/block_tensor_linear_operator.py | 21 ++++++++++++++----- 1 file changed, 16 insertions(+), 5 deletions(-) diff --git a/linear_operator/operators/block_tensor_linear_operator.py b/linear_operator/operators/block_tensor_linear_operator.py index 40c66b4b..1bd8b23d 100644 --- a/linear_operator/operators/block_tensor_linear_operator.py +++ b/linear_operator/operators/block_tensor_linear_operator.py @@ -11,7 +11,7 @@ class BlockTensorLinearOperator(LinearOperator): def __init__(self, linear_operators: List[List[LinearOperator]]) -> None: - assert isinstance(linear_operators, list), f"{self.__class__.__name__} expects a nested list of LinearOperators` + assert isinstance(linear_operators, list), f"{self.__class__.__name__} expects a nested list of LinearOperators" assert len(linear_operators) > 0, "must have non-empty list" assert len(linear_operators[0]) == len(linear_operators), "must be square over block dimensions" @@ -22,8 +22,19 @@ def __init__(self, linear_operators: List[List[LinearOperator]]) -> None: self.block_rows = linear_operators[0][0].shape[0] self.block_cols = linear_operators[0][0].shape[1] + # Check that provided operators all have the same shape + T = self.num_tasks + for i in range(T): + for j in range(T): + assert ( + linear_operators[i][j].shape[0] == self.block_rows + ), "the number of rows much match for all linear operators" + assert ( + linear_operators[i][j].shape[1] == self.block_cols + ), "the number of columns much match for all linear operators" + @staticmethod - def square_ops(T): + def create_square_ops_output(T: int) -> List[List[LinearOperator]]: """Return an empty (square) list of operators of shape TxT""" ops = [] for i in range(T): @@ -43,7 +54,7 @@ def _matmul( # A is block [N * T1, M * T2] and B is block [O * S1, P * S2]. If A and B have conformal block counts # ie T2==S1 as well as M==O then use the blockwise algorithm. 
Else use to_dense() if isinstance(rhs, self.__class__) and self.num_tasks == rhs.num_tasks and self.block_cols == rhs.block_rows: - output = BlockTensorLinearOperator.square_ops(T) + output = BlockTensorLinearOperator.create_square_ops_output(T) for i in range(T): for j in range(T): out_ij = self.linear_operators[i][0] @ rhs.linear_operators[0][j] @@ -111,6 +122,7 @@ def representation(self) -> Tuple[torch.Tensor, ...]: def _diag(self): out = [] for i in range(self.num_tasks): + # The underlying operators will test if they are square diagonal = self.linear_operators[i][i].diagonal() out.append(diagonal) return torch.concat(out, axis=1) @@ -126,14 +138,13 @@ def _transpose_nonbatch(self: Float[LinearOperator, "*batch M N"]) -> Float[Line def _getitem(self, row_index: IndexType, col_index: IndexType, *batch_indices: IndexType) -> LinearOperator: # Perform the __getitem__ - # TODO make this faster, see block_linear_operator tsr = self.to_dense() res = tsr[(*batch_indices, row_index, col_index)] return DenseLinearOperator(res) @classmethod def from_tensor(cls, tensor: Tensor, num_tasks: int) -> "BlockTensorLinearOperator": - def tensor_to_linear_op(t): + def tensor_to_linear_op(t: Tensor) -> LinearOperator: if torch.count_nonzero(t) > 0: return DenseLinearOperator(t) return ZeroLinearOperator(*t.size(), dtype=t.dtype, device=t.device) From 2f9b4b788a64fef1f4c30956afc0ca085691c1d6 Mon Sep 17 00:00:00 2001 From: Corwin Joy Date: Thu, 1 Jun 2023 16:22:10 -0700 Subject: [PATCH 11/23] Rename class to MatrixLinearOperator --- linear_operator/operators/__init__.py | 4 ++-- ...ear_operator.py => matrix_linear_operator.py} | 16 +++++++++++----- ...perator.py => test_matrix_linear_operator.py} | 12 ++++++------ 3 files changed, 19 insertions(+), 13 deletions(-) rename linear_operator/operators/{block_tensor_linear_operator.py => matrix_linear_operator.py} (93%) rename test/operators/{test_block_tensor_linear_operator.py => test_matrix_linear_operator.py} (89%) diff --git a/linear_operator/operators/__init__.py b/linear_operator/operators/__init__.py index 96ea8a61..bcf06184 100644 --- a/linear_operator/operators/__init__.py +++ b/linear_operator/operators/__init__.py @@ -6,7 +6,6 @@ from .block_diag_linear_operator import BlockDiagLinearOperator from .block_interleaved_linear_operator import BlockInterleavedLinearOperator from .block_linear_operator import BlockLinearOperator -from .block_tensor_linear_operator import BlockTensorLinearOperator from .cat_linear_operator import cat, CatLinearOperator from .chol_linear_operator import CholLinearOperator from .constant_mul_linear_operator import ConstantMulLinearOperator @@ -24,6 +23,7 @@ from .low_rank_root_added_diag_linear_operator import LowRankRootAddedDiagLinearOperator from .low_rank_root_linear_operator import LowRankRootLinearOperator from .matmul_linear_operator import MatmulLinearOperator +from .matrix_linear_operator import MatrixLinearOperator from .mul_linear_operator import MulLinearOperator from .permutation_linear_operator import PermutationLinearOperator, TransposePermutationLinearOperator from .psd_sum_linear_operator import PsdSumLinearOperator @@ -45,7 +45,7 @@ "BlockLinearOperator", "BlockDiagLinearOperator", "BlockInterleavedLinearOperator", - "BlockTensorLinearOperator", + "MatrixLinearOperator", "CatLinearOperator", "CholLinearOperator", "ConstantDiagLinearOperator", diff --git a/linear_operator/operators/block_tensor_linear_operator.py b/linear_operator/operators/matrix_linear_operator.py similarity index 93% rename from 
linear_operator/operators/block_tensor_linear_operator.py rename to linear_operator/operators/matrix_linear_operator.py index 1bd8b23d..5b297db9 100644 --- a/linear_operator/operators/block_tensor_linear_operator.py +++ b/linear_operator/operators/matrix_linear_operator.py @@ -9,7 +9,13 @@ from .zero_linear_operator import ZeroLinearOperator -class BlockTensorLinearOperator(LinearOperator): +class MatrixLinearOperator(LinearOperator): + """ + A TxT matrix of LinearOperators + + :param linear_operators: A TxT nested list of linear operators reprsenting a 2-D matrix + """ + def __init__(self, linear_operators: List[List[LinearOperator]]) -> None: assert isinstance(linear_operators, list), f"{self.__class__.__name__} expects a nested list of LinearOperators" assert len(linear_operators) > 0, "must have non-empty list" @@ -54,7 +60,7 @@ def _matmul( # A is block [N * T1, M * T2] and B is block [O * S1, P * S2]. If A and B have conformal block counts # ie T2==S1 as well as M==O then use the blockwise algorithm. Else use to_dense() if isinstance(rhs, self.__class__) and self.num_tasks == rhs.num_tasks and self.block_cols == rhs.block_rows: - output = BlockTensorLinearOperator.create_square_ops_output(T) + output = MatrixLinearOperator.create_square_ops_output(T) for i in range(T): for j in range(T): out_ij = self.linear_operators[i][0] @ rhs.linear_operators[0][j] @@ -73,7 +79,7 @@ def _matmul( P_T = rhs.size(1) // T rhs_blocks_raw = rhs.reshape(T, O_T, T, P_T) rhs_blocks = rhs_blocks_raw.permute(0, 2, 1, 3) - rhs_op = BlockTensorLinearOperator.from_tensor(rhs_blocks, T) + rhs_op = MatrixLinearOperator.from_tensor(rhs_blocks, T) return self._matmul(rhs_op) # Failover implementation. Convert to dense and multiply matricies @@ -134,7 +140,7 @@ def _transpose_nonbatch(self: Float[LinearOperator, "*batch M N"]) -> Float[Line for j in range(self.num_tasks): rows.append(self.linear_operators[j][i].mT) out.append(rows) - return BlockTensorLinearOperator(out) + return MatrixLinearOperator(out) def _getitem(self, row_index: IndexType, col_index: IndexType, *batch_indices: IndexType) -> LinearOperator: # Perform the __getitem__ @@ -143,7 +149,7 @@ def _getitem(self, row_index: IndexType, col_index: IndexType, *batch_indices: I return DenseLinearOperator(res) @classmethod - def from_tensor(cls, tensor: Tensor, num_tasks: int) -> "BlockTensorLinearOperator": + def from_tensor(cls, tensor: Tensor, num_tasks: int) -> "MatrixLinearOperator": def tensor_to_linear_op(t: Tensor) -> LinearOperator: if torch.count_nonzero(t) > 0: return DenseLinearOperator(t) diff --git a/test/operators/test_block_tensor_linear_operator.py b/test/operators/test_matrix_linear_operator.py similarity index 89% rename from test/operators/test_block_tensor_linear_operator.py rename to test/operators/test_matrix_linear_operator.py index 76567a72..3a08845e 100644 --- a/test/operators/test_block_tensor_linear_operator.py +++ b/test/operators/test_matrix_linear_operator.py @@ -4,7 +4,7 @@ import torch -from linear_operator.operators import BlockTensorLinearOperator +from linear_operator.operators import MatrixLinearOperator from linear_operator.test.base_test_case import BaseTestCase from linear_operator.test.linear_operator_core_test_case import CoreLinearOperatorTestCase @@ -26,8 +26,8 @@ def test_multiply(self): A = torch.randn(T, T, N, M) B = torch.randn(T, T, M, K) - A_blo = BlockTensorLinearOperator.from_tensor(A, T) - B_blo = BlockTensorLinearOperator.from_tensor(B, T) + A_blo = MatrixLinearOperator.from_tensor(A, T) + B_blo = 
MatrixLinearOperator.from_tensor(B, T) res_AB = A_blo._matmul(B_blo) res_dense_AB = res_AB.to_dense() @@ -71,8 +71,8 @@ def test_sparse_multiply(self): B = self.dense_to_4d(B_dense, T) # A_blo will contain dense operators along the diagonal + Zero operators off diagonal - A_blo = BlockTensorLinearOperator.from_tensor(A, T) - B_blo = BlockTensorLinearOperator.from_tensor(B, T) + A_blo = MatrixLinearOperator.from_tensor(A, T) + B_blo = MatrixLinearOperator.from_tensor(B, T) res_AB = A_blo._matmul(B_blo) res_dense_AB = res_AB.to_dense() @@ -92,7 +92,7 @@ class TestLinearOperatorBlockTensorLinearOperator(CoreLinearOperatorTestCase, un A_blocks = A_dense.reshape(T, N, T, M).permute(0, 2, 1, 3) def create_linear_op(self): - A_blo = BlockTensorLinearOperator.from_tensor(self.A_blocks, self.T) + A_blo = MatrixLinearOperator.from_tensor(self.A_blocks, self.T) return A_blo def evaluate_linear_op(self, linear_op): From 7cbbc54a3d395c0b233eaaa5063f5a954f3b7868 Mon Sep 17 00:00:00 2001 From: Corwin Joy Date: Thu, 1 Jun 2023 16:34:59 -0700 Subject: [PATCH 12/23] Add parts omitted from base test case. Show them as commented out to make it clear what we are skipping. --- .../test/linear_operator_core_test_case.py | 209 ++++++++++++++++++ 1 file changed, 209 insertions(+) diff --git a/linear_operator/test/linear_operator_core_test_case.py b/linear_operator/test/linear_operator_core_test_case.py index 014b04cd..0b238810 100644 --- a/linear_operator/test/linear_operator_core_test_case.py +++ b/linear_operator/test/linear_operator_core_test_case.py @@ -48,6 +48,13 @@ def _test_matmul(self, rhs): res_evaluated = to_dense(res) self.assertAllClose(res_evaluated, actual) + # grad = torch.randn_like(res_evaluated) + # res_evaluated.backward(gradient=grad) + # actual.backward(gradient=grad) + # for arg, arg_copy in zip(linear_op.representation(), linear_op_copy.representation()): + # if arg_copy.requires_grad and arg_copy.is_leaf and arg_copy.grad is not None: + # self.assertAllClose(arg.grad, arg_copy.grad, **self.tolerances["matmul"]) + # Test __torch_function__ res = torch.matmul(linear_op, rhs) actual = evaluated.matmul(rhs) @@ -83,6 +90,13 @@ def _test_rmatmul(self, lhs): actual = torch.matmul(lhs, evaluated) self.assertAllClose(res_evaluated, actual) + # grad = torch.randn_like(res) + # res.backward(gradient=grad) + # actual.backward(gradient=grad) + # for arg, arg_copy in zip(linear_op.representation(), linear_op_copy.representation()): + # if arg_copy.requires_grad and arg_copy.is_leaf and arg_copy.grad is not None: + # self.assertAllClose(arg.grad, arg_copy.grad, **self.tolerances["matmul"]) + def test_add(self): linear_op = self.create_linear_op() evaluated = self.evaluate_linear_op(linear_op) @@ -101,6 +115,9 @@ def test_add(self): rhs = torch.randn(linear_op.matrix_shape) self.assertAllClose((linear_op + rhs).to_dense(), evaluated + rhs) + # rhs = torch.randn(2, *linear_op.shape) + # self.assertAllClose((linear_op + rhs).to_dense(), evaluated + rhs) + self.assertAllClose((linear_op + linear_op).to_dense(), evaluated * 2) def test_matmul_vec(self): @@ -180,6 +197,37 @@ def test_getitem(self): actual = evaluated[0:2, ..., 2] self.assertAllClose(res, actual) + # # Batch case + # else: + # res = linear_op[1].to_dense() + # actual = evaluated[1] + # self.assertAllClose(res, actual) + # res = linear_op[0:2].to_dense() + # actual = evaluated[0:2] + # self.assertAllClose(res, actual) + # res = linear_op[:, 0:2].to_dense() + # actual = evaluated[:, 0:2] + # self.assertAllClose(res, actual) + # + # for 
batch_index in product([1, slice(0, 2, None)], repeat=(linear_op.dim() - 2)): + # res = linear_op.__getitem__((*batch_index, slice(0, 1, None), slice(0, 2, None))).to_dense() + # actual = evaluated.__getitem__((*batch_index, slice(0, 1, None), slice(0, 2, None))) + # self.assertAllClose(res, actual) + # res = linear_op.__getitem__((*batch_index, 1, slice(0, 2, None))) + # actual = evaluated.__getitem__((*batch_index, 1, slice(0, 2, None))) + # self.assertAllClose(res, actual) + # res = linear_op.__getitem__((*batch_index, slice(1, None, None), 2)) + # actual = evaluated.__getitem__((*batch_index, slice(1, None, None), 2)) + # self.assertAllClose(res, actual) + # + # # Ellipsis + # res = linear_op.__getitem__((Ellipsis, slice(1, None, None), 2)) + # actual = evaluated.__getitem__((Ellipsis, slice(1, None, None), 2)) + # self.assertAllClose(res, actual) + # res = linear_op.__getitem__((slice(1, None, None), Ellipsis, 2)) + # actual = evaluated.__getitem__((slice(1, None, None), Ellipsis, 2)) + # self.assertAllClose(res, actual) + def test_getitem_tensor_index(self): linear_op = self.create_linear_op() evaluated = self.evaluate_linear_op(linear_op) @@ -205,6 +253,53 @@ def test_getitem_tensor_index(self): res, actual = linear_op[index], evaluated[index] self.assertAllClose(res, actual) + # # Batch case + # else: + # for batch_index in product( + # [torch.tensor([0, 1, 1, 0]), slice(None, None, None)], + # repeat=(linear_op.dim() - 2), + # ): + # index = ( + # *batch_index, + # torch.tensor([0, 1, 0, 2]), + # torch.tensor([1, 2, 0, 1]), + # ) + # res, actual = linear_op[index], evaluated[index] + # self.assertAllClose(res, actual) + # index = ( + # *batch_index, + # torch.tensor([0, 1, 0, 2]), + # slice(None, None, None), + # ) + # res, actual = ( + # linear_operator.to_dense(linear_op[index]), + # evaluated[index], + # ) + # self.assertAllClose(res, actual) + # index = ( + # *batch_index, + # slice(None, None, None), + # torch.tensor([0, 1, 2, 1]), + # ) + # res, actual = ( + # linear_operator.to_dense(linear_op[index]), + # evaluated[index], + # ) + # self.assertAllClose(res, actual) + # index = (*batch_index, slice(None, None, None), slice(None, None, None)) + # res, actual = linear_op[index].to_dense(), evaluated[index] + # self.assertAllClose(res, actual) + # + # # Ellipsis + # res = linear_op.__getitem__((Ellipsis, torch.tensor([0, 1, 0, 2]), torch.tensor([1, 2, 0, 1]))) + # actual = evaluated.__getitem__((Ellipsis, torch.tensor([0, 1, 0, 2]), torch.tensor([1, 2, 0, 1]))) + # self.assertAllClose(res, actual) + # res = linear_operator.to_dense( + # linear_op.__getitem__((torch.tensor([0, 1, 0, 1]), Ellipsis, torch.tensor([1, 2, 0, 1]))) + # ) + # actual = evaluated.__getitem__((torch.tensor([0, 1, 0, 1]), Ellipsis, torch.tensor([1, 2, 0, 1]))) + # self.assertAllClose(res, actual) + def test_getitem_broadcasted_tensor_index(self): linear_op = self.create_linear_op() evaluated = self.evaluate_linear_op(linear_op) @@ -225,6 +320,62 @@ def test_getitem_broadcasted_tensor_index(self): res, actual = linear_op[index], evaluated[index] self.assertAllClose(res, actual) + # # Batch case + # else: + # for batch_index in product( + # [torch.tensor([0, 1, 1]).view(-1, 1, 1), slice(None, None, None)], + # repeat=(linear_op.dim() - 2), + # ): + # index = ( + # *batch_index, + # torch.tensor([0, 1]).view(-1, 1), + # torch.tensor([1, 2, 0, 1]).view(1, -1), + # ) + # res, actual = linear_op[index], evaluated[index] + # self.assertAllClose(res, actual) + # res, actual = ( + # 
linear_operator.to_dense(linear_op[index]), + # evaluated[index], + # ) + # self.assertAllClose(res, actual) + # index = (*batch_index, slice(None, None, None), slice(None, None, None)) + # res, actual = linear_op[index].to_dense(), evaluated[index] + # self.assertAllClose(res, actual) + # + # # Ellipsis + # res = linear_op.__getitem__( + # ( + # Ellipsis, + # torch.tensor([0, 1, 0]).view(-1, 1, 1), + # torch.tensor([1, 2, 0, 1]).view(1, 1, -1), + # ) + # ) + # actual = evaluated.__getitem__( + # ( + # Ellipsis, + # torch.tensor([0, 1, 0]).view(-1, 1, 1), + # torch.tensor([1, 2, 0, 1]).view(1, 1, -1), + # ) + # ) + # self.assertAllClose(res, actual) + # res = linear_operator.to_dense( + # linear_op.__getitem__( + # ( + # torch.tensor([0, 1, 0]).view(1, -1), + # Ellipsis, + # torch.tensor([1, 2, 0, 1]).view(-1, 1), + # ) + # ) + # ) + # actual = evaluated.__getitem__( + # ( + # torch.tensor([0, 1, 0]).view(1, -1), + # Ellipsis, + # torch.tensor([1, 2, 0, 1]).view(-1, 1), + # ) + # ) + # self.assertAllClose(res, actual) + def test_permute(self): linear_op = self.create_linear_op() if linear_op.dim() >= 4: @@ -268,6 +419,44 @@ def test_matmul_diag_matrix(self): rhs = DiagLinearOperator(diag) return self._test_matmul(rhs) + # def test_matmul_matrix_broadcast(self): + # linear_op = self.create_linear_op() + # + # # Right hand size has one more batch dimension + # batch_shape = torch.Size((3, *linear_op.batch_shape)) + # rhs = torch.randn(*batch_shape, linear_op.size(-1), 4) + # self._test_matmul(rhs) + # + # if linear_op.ndimension() > 2: + # # Right hand size has one fewer batch dimension + # batch_shape = torch.Size(linear_op.batch_shape[1:]) + # rhs = torch.randn(*batch_shape, linear_op.size(-1), 4) + # self._test_matmul(rhs) + # + # # Right hand size has a singleton dimension + # batch_shape = torch.Size((*linear_op.batch_shape[:-1], 1)) + # rhs = torch.randn(*batch_shape, linear_op.size(-1), 4) + # self._test_matmul(rhs) + # + # def test_rmatmul_matrix_broadcast(self): + # linear_op = self.create_linear_op() + # + # # Left hand size has one more batch dimension + # batch_shape = torch.Size((3, *linear_op.batch_shape)) + # lhs = torch.randn(*batch_shape, 4, linear_op.size(-2)) + # self._test_rmatmul(lhs) + # + # if linear_op.ndimension() > 2: + # # Left hand size has one fewer batch dimension + # batch_shape = torch.Size(linear_op.batch_shape[1:]) + # lhs = torch.randn(*batch_shape, 4, linear_op.size(-2)) + # self._test_rmatmul(lhs) + # + # # Left hand size has a singleton dimension + # batch_shape = torch.Size((*linear_op.batch_shape[:-1], 1)) + # lhs = torch.randn(*batch_shape, 4, linear_op.size(-2)) + # self._test_rmatmul(lhs) + def test_rsub(self): linear_op = self.create_linear_op() evaluated = self.evaluate_linear_op(linear_op) @@ -299,6 +488,26 @@ def test_sum(self): if linear_op.ndimension() > 3: self.assertAllClose(torch.sum(linear_op, -4).to_dense(), torch.sum(evaluated, -4)) + # def test_squeeze_unsqueeze(self): + # linear_operator = self.create_linear_op() + # evaluated = self.evaluate_linear_op(linear_operator) + # + # unsqueezed = torch.unsqueeze(linear_operator, -3) + # self.assertAllClose(unsqueezed.to_dense(), evaluated.unsqueeze(-3)) + # + # squeezed = torch.squeeze(unsqueezed, -3) + # self.assertAllClose(squeezed.to_dense(), evaluated) + # + # def test_transpose_batch(self): + # linear_op = self.create_linear_op() + # evaluated = self.evaluate_linear_op(linear_op) + # + # if linear_op.dim() >= 4: + # for i, j in combinations(range(linear_op.dim() - 2), 2): + # res = 
torch.transpose(linear_op, i, j).to_dense() + # actual = torch.transpose(evaluated, i, j) + # self.assertAllClose(res, actual, **self.tolerances["transpose"]) + def test_add_jitter(self): linear_op = self.create_linear_op() evaluated = self.evaluate_linear_op(linear_op) From 32e9a52016cef054e3a472c3f449ab8671403de1 Mon Sep 17 00:00:00 2001 From: Corwin Joy Date: Thu, 1 Jun 2023 16:51:33 -0700 Subject: [PATCH 13/23] Rename class to BlockMatrixLinearOperator --- linear_operator/operators/__init__.py | 4 ++-- ...tor.py => block_matrix_linear_operator.py} | 20 ++++++++++++------- ...y => test_block_matrix_linear_operator.py} | 12 +++++------ 3 files changed, 21 insertions(+), 15 deletions(-) rename linear_operator/operators/{matrix_linear_operator.py => block_matrix_linear_operator.py} (89%) rename test/operators/{test_matrix_linear_operator.py => test_block_matrix_linear_operator.py} (89%) diff --git a/linear_operator/operators/__init__.py b/linear_operator/operators/__init__.py index bcf06184..9d5862a9 100644 --- a/linear_operator/operators/__init__.py +++ b/linear_operator/operators/__init__.py @@ -6,6 +6,7 @@ from .block_diag_linear_operator import BlockDiagLinearOperator from .block_interleaved_linear_operator import BlockInterleavedLinearOperator from .block_linear_operator import BlockLinearOperator +from .block_matrix_linear_operator import BlockMatrixLinearOperator from .cat_linear_operator import cat, CatLinearOperator from .chol_linear_operator import CholLinearOperator from .constant_mul_linear_operator import ConstantMulLinearOperator @@ -23,7 +24,6 @@ from .low_rank_root_added_diag_linear_operator import LowRankRootAddedDiagLinearOperator from .low_rank_root_linear_operator import LowRankRootLinearOperator from .matmul_linear_operator import MatmulLinearOperator -from .matrix_linear_operator import MatrixLinearOperator from .mul_linear_operator import MulLinearOperator from .permutation_linear_operator import PermutationLinearOperator, TransposePermutationLinearOperator from .psd_sum_linear_operator import PsdSumLinearOperator @@ -45,7 +45,7 @@ "BlockLinearOperator", "BlockDiagLinearOperator", "BlockInterleavedLinearOperator", - "MatrixLinearOperator", + "BlockMatrixLinearOperator", "CatLinearOperator", "CholLinearOperator", "ConstantDiagLinearOperator", diff --git a/linear_operator/operators/matrix_linear_operator.py b/linear_operator/operators/block_matrix_linear_operator.py similarity index 89% rename from linear_operator/operators/matrix_linear_operator.py rename to linear_operator/operators/block_matrix_linear_operator.py index 5b297db9..2d2be435 100644 --- a/linear_operator/operators/matrix_linear_operator.py +++ b/linear_operator/operators/block_matrix_linear_operator.py @@ -9,11 +9,17 @@ from .zero_linear_operator import ZeroLinearOperator -class MatrixLinearOperator(LinearOperator): +class BlockMatrixLinearOperator(LinearOperator): """ - A TxT matrix of LinearOperators + A TxT block matrix of LinearOperators. - :param linear_operators: A TxT nested list of linear operators reprsenting a 2-D matrix + Idea. Represent [TN, TM] tensors by TxT blocks of NxM lazy tensors. + + Implementation. A block linear operator class that can keep track of the [T, T] block structure, + represented as T^2 lazy tensors of the same shape. Implement matrix multiplication between block matrices as + the appropriate linear operators on the blocks. 
+ + :param linear_operators: A TxT nested list of linear operators representing a 2-D matrix """ def __init__(self, linear_operators: List[List[LinearOperator]]) -> None: @@ -60,7 +66,7 @@ def _matmul( # A is block [N * T1, M * T2] and B is block [O * S1, P * S2]. If A and B have conformal block counts # ie T2==S1 as well as M==O then use the blockwise algorithm. Else use to_dense() if isinstance(rhs, self.__class__) and self.num_tasks == rhs.num_tasks and self.block_cols == rhs.block_rows: - output = MatrixLinearOperator.create_square_ops_output(T) + output = BlockMatrixLinearOperator.create_square_ops_output(T) for i in range(T): for j in range(T): out_ij = self.linear_operators[i][0] @ rhs.linear_operators[0][j] @@ -79,7 +85,7 @@ def _matmul( P_T = rhs.size(1) // T rhs_blocks_raw = rhs.reshape(T, O_T, T, P_T) rhs_blocks = rhs_blocks_raw.permute(0, 2, 1, 3) - rhs_op = MatrixLinearOperator.from_tensor(rhs_blocks, T) + rhs_op = BlockMatrixLinearOperator.from_tensor(rhs_blocks, T) return self._matmul(rhs_op) # Failover implementation. Convert to dense and multiply matricies @@ -140,7 +146,7 @@ def _transpose_nonbatch(self: Float[LinearOperator, "*batch M N"]) -> Float[Line for j in range(self.num_tasks): rows.append(self.linear_operators[j][i].mT) out.append(rows) - return MatrixLinearOperator(out) + return BlockMatrixLinearOperator(out) def _getitem(self, row_index: IndexType, col_index: IndexType, *batch_indices: IndexType) -> LinearOperator: # Perform the __getitem__ @@ -149,7 +155,7 @@ def _getitem(self, row_index: IndexType, col_index: IndexType, *batch_indices: I return DenseLinearOperator(res) @classmethod - def from_tensor(cls, tensor: Tensor, num_tasks: int) -> "MatrixLinearOperator": + def from_tensor(cls, tensor: Tensor, num_tasks: int) -> "BlockMatrixLinearOperator": def tensor_to_linear_op(t: Tensor) -> LinearOperator: if torch.count_nonzero(t) > 0: return DenseLinearOperator(t) diff --git a/test/operators/test_matrix_linear_operator.py b/test/operators/test_block_matrix_linear_operator.py similarity index 89% rename from test/operators/test_matrix_linear_operator.py rename to test/operators/test_block_matrix_linear_operator.py index 3a08845e..ca8cdc16 100644 --- a/test/operators/test_matrix_linear_operator.py +++ b/test/operators/test_block_matrix_linear_operator.py @@ -4,7 +4,7 @@ import torch -from linear_operator.operators import MatrixLinearOperator +from linear_operator.operators import BlockMatrixLinearOperator from linear_operator.test.base_test_case import BaseTestCase from linear_operator.test.linear_operator_core_test_case import CoreLinearOperatorTestCase @@ -26,8 +26,8 @@ def test_multiply(self): A = torch.randn(T, T, N, M) B = torch.randn(T, T, M, K) - A_blo = MatrixLinearOperator.from_tensor(A, T) - B_blo = MatrixLinearOperator.from_tensor(B, T) + A_blo = BlockMatrixLinearOperator.from_tensor(A, T) + B_blo = BlockMatrixLinearOperator.from_tensor(B, T) res_AB = A_blo._matmul(B_blo) res_dense_AB = res_AB.to_dense() @@ -71,8 +71,8 @@ def test_sparse_multiply(self): B = self.dense_to_4d(B_dense, T) # A_blo will contain dense operators along the diagonal + Zero operators off diagonal - A_blo = MatrixLinearOperator.from_tensor(A, T) - B_blo = MatrixLinearOperator.from_tensor(B, T) + A_blo = BlockMatrixLinearOperator.from_tensor(A, T) + B_blo = BlockMatrixLinearOperator.from_tensor(B, T) res_AB = A_blo._matmul(B_blo) res_dense_AB = res_AB.to_dense() @@ -92,7 +92,7 @@ class TestLinearOperatorBlockTensorLinearOperator(CoreLinearOperatorTestCase, un A_blocks = 
A_dense.reshape(T, N, T, M).permute(0, 2, 1, 3) def create_linear_op(self): - A_blo = MatrixLinearOperator.from_tensor(self.A_blocks, self.T) + A_blo = BlockMatrixLinearOperator.from_tensor(self.A_blocks, self.T) return A_blo def evaluate_linear_op(self, linear_op): From 30ad0eda57e0b52771b8a26cabf072d2eae3384e Mon Sep 17 00:00:00 2001 From: Corwin Joy Date: Thu, 1 Jun 2023 16:54:27 -0700 Subject: [PATCH 14/23] Fix type signature --- linear_operator/operators/block_matrix_linear_operator.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/linear_operator/operators/block_matrix_linear_operator.py b/linear_operator/operators/block_matrix_linear_operator.py index 2d2be435..eb9d1d23 100644 --- a/linear_operator/operators/block_matrix_linear_operator.py +++ b/linear_operator/operators/block_matrix_linear_operator.py @@ -131,7 +131,7 @@ def representation(self) -> Tuple[torch.Tensor, ...]: representation += tuple(op.representation()) return tuple(representation) - def _diag(self): + def _diag(self: Float[LinearOperator, "... M N"]) -> Float[torch.Tensor, "... N"]: out = [] for i in range(self.num_tasks): # The underlying operators will test if they are square From 889ce0fd0a3e0f14f96ac584b3ff7ae7fd581b89 Mon Sep 17 00:00:00 2001 From: Corwin Joy Date: Thu, 1 Jun 2023 17:08:59 -0700 Subject: [PATCH 15/23] Improve comments --- linear_operator/test/linear_operator_core_test_case.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/linear_operator/test/linear_operator_core_test_case.py b/linear_operator/test/linear_operator_core_test_case.py index 0b238810..99e2e638 100644 --- a/linear_operator/test/linear_operator_core_test_case.py +++ b/linear_operator/test/linear_operator_core_test_case.py @@ -8,8 +8,8 @@ from linear_operator.operators import DiagLinearOperator, to_dense from .base_test_case import BaseTestCase -rem = """ -In code, a LinearOperator is a class that +""" +From the project description, a LinearOperator is a class that specifies the tensor(s) needed to define the LinearOperator, specifies a _matmul function (how the LinearOperator is applied to a vector), @@ -17,6 +17,8 @@ specifies a _transpose_nonbatch function (the adjoint of the LinearOperator). (optionally) defines other functions (e.g. logdet, eigh, etc.) to accelerate computations for which efficient sturcture-exploiting routines exist. + +What follows is a class to test these core LinearOperator operations """ From c557aa3bc9b56c140cc37cd4fe26e8a53edcaa2e Mon Sep 17 00:00:00 2001 From: Corwin Joy Date: Tue, 18 Jul 2023 16:34:45 -0700 Subject: [PATCH 16/23] Refactor linear_operator_test_case.py into a set of core tests and more advanced tests. This allows us to create and test operators that only support core operations. 
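
As an illustration only (not part of this change): an operator that implements just the core
interface can reuse the shared checks by subclassing the new CoreLinearOperatorTestCase and
providing the same create_linear_op / evaluate_linear_op hooks used by the block-matrix tests.
A minimal sketch, with DenseLinearOperator standing in for a hypothetical operator under test:

    import unittest

    import torch

    from linear_operator.operators import DenseLinearOperator
    from linear_operator.test.linear_operator_core_test_case import CoreLinearOperatorTestCase


    class TestMyCoreOnlyOperator(CoreLinearOperatorTestCase, unittest.TestCase):
        def create_linear_op(self):
            # Any operator implementing the core interface (_matmul, _size, _transpose_nonbatch)
            # can be returned here; a dense wrapper is used purely as a stand-in.
            return DenseLinearOperator(torch.randn(4, 4))

        def evaluate_linear_op(self, linear_op):
            # The dense tensor that the operator is expected to behave like.
            return linear_op.to_dense()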
---
 .../test/linear_operator_core_test_case.py    | 231 +-----------------
 .../test/linear_operator_test_case.py         | 229 ++---------------
 2 files changed, 29 insertions(+), 431 deletions(-)

diff --git a/linear_operator/test/linear_operator_core_test_case.py b/linear_operator/test/linear_operator_core_test_case.py
index 99e2e638..72f6e3d4 100644
--- a/linear_operator/test/linear_operator_core_test_case.py
+++ b/linear_operator/test/linear_operator_core_test_case.py
@@ -9,16 +9,17 @@ from .base_test_case import BaseTestCase
 
 """
-From the project description, a LinearOperator is a class that
+From the project description, a LinearOperator is a class that:
 
-specifies the tensor(s) needed to define the LinearOperator,
-specifies a _matmul function (how the LinearOperator is applied to a vector),
-specifies a _size function (how big is the LinearOperator if it is represented as a matrix, or batch of matrices), and
-specifies a _transpose_nonbatch function (the adjoint of the LinearOperator).
-(optionally) defines other functions (e.g. logdet, eigh, etc.) to accelerate computations for which efficient
-sturcture-exploiting routines exist.
+- specifies the tensor(s) needed to define the LinearOperator,
+- specifies a _matmul function (how the LinearOperator is applied to a vector),
+- specifies a _size function (how big is the LinearOperator if it is represented as a matrix, or batch of matrices), and
+- specifies a _transpose_nonbatch function (the adjoint of the LinearOperator).
+- (optionally) defines other functions (e.g. logdet, eigh, etc.) to accelerate computations for which efficient
+  structure-exploiting routines exist.
 
-What follows is a class to test these core LinearOperator operations
+What follows is a class to test these core LinearOperator operations.
+Note that batch operations are excluded here since they are not part of the core definition.
""" @@ -50,13 +51,6 @@ def _test_matmul(self, rhs): res_evaluated = to_dense(res) self.assertAllClose(res_evaluated, actual) - # grad = torch.randn_like(res_evaluated) - # res_evaluated.backward(gradient=grad) - # actual.backward(gradient=grad) - # for arg, arg_copy in zip(linear_op.representation(), linear_op_copy.representation()): - # if arg_copy.requires_grad and arg_copy.is_leaf and arg_copy.grad is not None: - # self.assertAllClose(arg.grad, arg_copy.grad, **self.tolerances["matmul"]) - # Test __torch_function__ res = torch.matmul(linear_op, rhs) actual = evaluated.matmul(rhs) @@ -92,13 +86,6 @@ def _test_rmatmul(self, lhs): actual = torch.matmul(lhs, evaluated) self.assertAllClose(res_evaluated, actual) - # grad = torch.randn_like(res) - # res.backward(gradient=grad) - # actual.backward(gradient=grad) - # for arg, arg_copy in zip(linear_op.representation(), linear_op_copy.representation()): - # if arg_copy.requires_grad and arg_copy.is_leaf and arg_copy.grad is not None: - # self.assertAllClose(arg.grad, arg_copy.grad, **self.tolerances["matmul"]) - def test_add(self): linear_op = self.create_linear_op() evaluated = self.evaluate_linear_op(linear_op) @@ -121,6 +108,7 @@ def test_add(self): # self.assertAllClose((linear_op + rhs).to_dense(), evaluated + rhs) self.assertAllClose((linear_op + linear_op).to_dense(), evaluated * 2) + return linear_op, evaluated def test_matmul_vec(self): linear_op = self.create_linear_op() @@ -199,37 +187,6 @@ def test_getitem(self): actual = evaluated[0:2, ..., 2] self.assertAllClose(res, actual) - # # Batch case - # else: - # res = linear_op[1].to_dense() - # actual = evaluated[1] - # self.assertAllClose(res, actual) - # res = linear_op[0:2].to_dense() - # actual = evaluated[0:2] - # self.assertAllClose(res, actual) - # res = linear_op[:, 0:2].to_dense() - # actual = evaluated[:, 0:2] - # self.assertAllClose(res, actual) - # - # for batch_index in product([1, slice(0, 2, None)], repeat=(linear_op.dim() - 2)): - # res = linear_op.__getitem__((*batch_index, slice(0, 1, None), slice(0, 2, None))).to_dense() - # actual = evaluated.__getitem__((*batch_index, slice(0, 1, None), slice(0, 2, None))) - # self.assertAllClose(res, actual) - # res = linear_op.__getitem__((*batch_index, 1, slice(0, 2, None))) - # actual = evaluated.__getitem__((*batch_index, 1, slice(0, 2, None))) - # self.assertAllClose(res, actual) - # res = linear_op.__getitem__((*batch_index, slice(1, None, None), 2)) - # actual = evaluated.__getitem__((*batch_index, slice(1, None, None), 2)) - # self.assertAllClose(res, actual) - # - # # Ellipsis - # res = linear_op.__getitem__((Ellipsis, slice(1, None, None), 2)) - # actual = evaluated.__getitem__((Ellipsis, slice(1, None, None), 2)) - # self.assertAllClose(res, actual) - # res = linear_op.__getitem__((slice(1, None, None), Ellipsis, 2)) - # actual = evaluated.__getitem__((slice(1, None, None), Ellipsis, 2)) - # self.assertAllClose(res, actual) - def test_getitem_tensor_index(self): linear_op = self.create_linear_op() evaluated = self.evaluate_linear_op(linear_op) @@ -255,53 +212,6 @@ def test_getitem_tensor_index(self): res, actual = linear_op[index], evaluated[index] self.assertAllClose(res, actual) - # # Batch case - # else: - # for batch_index in product( - # [torch.tensor([0, 1, 1, 0]), slice(None, None, None)], - # repeat=(linear_op.dim() - 2), - # ): - # index = ( - # *batch_index, - # torch.tensor([0, 1, 0, 2]), - # torch.tensor([1, 2, 0, 1]), - # ) - # res, actual = linear_op[index], evaluated[index] - # self.assertAllClose(res, 
actual) - # index = ( - # *batch_index, - # torch.tensor([0, 1, 0, 2]), - # slice(None, None, None), - # ) - # res, actual = ( - # linear_operator.to_dense(linear_op[index]), - # evaluated[index], - # ) - # self.assertAllClose(res, actual) - # index = ( - # *batch_index, - # slice(None, None, None), - # torch.tensor([0, 1, 2, 1]), - # ) - # res, actual = ( - # linear_operator.to_dense(linear_op[index]), - # evaluated[index], - # ) - # self.assertAllClose(res, actual) - # index = (*batch_index, slice(None, None, None), slice(None, None, None)) - # res, actual = linear_op[index].to_dense(), evaluated[index] - # self.assertAllClose(res, actual) - # - # # Ellipsis - # res = linear_op.__getitem__((Ellipsis, torch.tensor([0, 1, 0, 2]), torch.tensor([1, 2, 0, 1]))) - # actual = evaluated.__getitem__((Ellipsis, torch.tensor([0, 1, 0, 2]), torch.tensor([1, 2, 0, 1]))) - # self.assertAllClose(res, actual) - # res = linear_operator.to_dense( - # linear_op.__getitem__((torch.tensor([0, 1, 0, 1]), Ellipsis, torch.tensor([1, 2, 0, 1]))) - # ) - # actual = evaluated.__getitem__((torch.tensor([0, 1, 0, 1]), Ellipsis, torch.tensor([1, 2, 0, 1]))) - # self.assertAllClose(res, actual) - def test_getitem_broadcasted_tensor_index(self): linear_op = self.create_linear_op() evaluated = self.evaluate_linear_op(linear_op) @@ -322,62 +232,6 @@ def test_getitem_broadcasted_tensor_index(self): res, actual = linear_op[index], evaluated[index] self.assertAllClose(res, actual) - # # Batch case - # else: - # for batch_index in product( - # [torch.tensor([0, 1, 1]).view(-1, 1, 1), slice(None, None, None)], - # repeat=(linear_op.dim() - 2), - # ): - # index = ( - # *batch_index, - # torch.tensor([0, 1]).view(-1, 1), - # torch.tensor([1, 2, 0, 1]).view(1, -1), - # ) - # res, actual = linear_op[index], evaluated[index] - # self.assertAllClose(res, actual) - # res, actual = ( - # linear_operator.to_dense(linear_op[index]), - # evaluated[index], - # ) - # self.assertAllClose(res, actual) - # index = (*batch_index, slice(None, None, None), slice(None, None, None)) - # res, actual = linear_op[index].to_dense(), evaluated[index] - # self.assertAllClose(res, actual) - # - # # Ellipsis - # res = linear_op.__getitem__( - # ( - # Ellipsis, - # torch.tensor([0, 1, 0]).view(-1, 1, 1), - # torch.tensor([1, 2, 0, 1]).view(1, 1, -1), - # ) - # ) - # actual = evaluated.__getitem__( - # ( - # Ellipsis, - # torch.tensor([0, 1, 0]).view(-1, 1, 1), - # torch.tensor([1, 2, 0, 1]).view(1, 1, -1), - # ) - # ) - # self.assertAllClose(res, actual) - # res = linear_operator.to_dense( - # linear_op.__getitem__( - # ( - # torch.tensor([0, 1, 0]).view(1, -1), - # Ellipsis, - # torch.tensor([1, 2, 0, 1]).view(-1, 1), - # ) - # ) - # ) - # actual = evaluated.__getitem__( - # ( - # torch.tensor([0, 1, 0]).view(1, -1), - # Ellipsis, - # torch.tensor([1, 2, 0, 1]).view(-1, 1), - # ) - # ) - # self.assertAllClose(res, actual) - def test_permute(self): linear_op = self.create_linear_op() if linear_op.dim() >= 4: @@ -421,44 +275,6 @@ def test_matmul_diag_matrix(self): rhs = DiagLinearOperator(diag) return self._test_matmul(rhs) - # def test_matmul_matrix_broadcast(self): - # linear_op = self.create_linear_op() - # - # # Right hand size has one more batch dimension - # batch_shape = torch.Size((3, *linear_op.batch_shape)) - # rhs = torch.randn(*batch_shape, linear_op.size(-1), 4) - # self._test_matmul(rhs) - # - # if linear_op.ndimension() > 2: - # # Right hand size has one fewer batch dimension - # batch_shape = torch.Size(linear_op.batch_shape[1:]) - # rhs = 
torch.randn(*batch_shape, linear_op.size(-1), 4) - # self._test_matmul(rhs) - # - # # Right hand size has a singleton dimension - # batch_shape = torch.Size((*linear_op.batch_shape[:-1], 1)) - # rhs = torch.randn(*batch_shape, linear_op.size(-1), 4) - # self._test_matmul(rhs) - # - # def test_rmatmul_matrix_broadcast(self): - # linear_op = self.create_linear_op() - # - # # Left hand size has one more batch dimension - # batch_shape = torch.Size((3, *linear_op.batch_shape)) - # lhs = torch.randn(*batch_shape, 4, linear_op.size(-2)) - # self._test_rmatmul(lhs) - # - # if linear_op.ndimension() > 2: - # # Left hand size has one fewer batch dimension - # batch_shape = torch.Size(linear_op.batch_shape[1:]) - # lhs = torch.randn(*batch_shape, 4, linear_op.size(-2)) - # self._test_rmatmul(lhs) - # - # # Left hand size has a singleton dimension - # batch_shape = torch.Size((*linear_op.batch_shape[:-1], 1)) - # lhs = torch.randn(*batch_shape, 4, linear_op.size(-2)) - # self._test_rmatmul(lhs) - def test_rsub(self): linear_op = self.create_linear_op() evaluated = self.evaluate_linear_op(linear_op) @@ -489,30 +305,3 @@ def test_sum(self): self.assertAllClose(torch.sum(linear_op, -3).to_dense(), torch.sum(evaluated, -3)) if linear_op.ndimension() > 3: self.assertAllClose(torch.sum(linear_op, -4).to_dense(), torch.sum(evaluated, -4)) - - # def test_squeeze_unsqueeze(self): - # linear_operator = self.create_linear_op() - # evaluated = self.evaluate_linear_op(linear_operator) - # - # unsqueezed = torch.unsqueeze(linear_operator, -3) - # self.assertAllClose(unsqueezed.to_dense(), evaluated.unsqueeze(-3)) - # - # squeezed = torch.squeeze(unsqueezed, -3) - # self.assertAllClose(squeezed.to_dense(), evaluated) - # - # def test_transpose_batch(self): - # linear_op = self.create_linear_op() - # evaluated = self.evaluate_linear_op(linear_op) - # - # if linear_op.dim() >= 4: - # for i, j in combinations(range(linear_op.dim() - 2), 2): - # res = torch.transpose(linear_op, i, j).to_dense() - # actual = torch.transpose(evaluated, i, j) - # self.assertAllClose(res, actual, **self.tolerances["transpose"]) - - def test_add_jitter(self): - linear_op = self.create_linear_op() - evaluated = self.evaluate_linear_op(linear_op) - res = linear_operator.add_jitter(linear_op, 0.4).to_dense() - actual = evaluated + torch.eye(evaluated.size(-1)).mul_(0.4) - self.assertAllClose(res, actual) diff --git a/linear_operator/test/linear_operator_test_case.py b/linear_operator/test/linear_operator_test_case.py index 8f2b79ae..772ecc10 100644 --- a/linear_operator/test/linear_operator_test_case.py +++ b/linear_operator/test/linear_operator_test_case.py @@ -10,16 +10,16 @@ import torch import linear_operator -from linear_operator.operators import DenseLinearOperator, DiagLinearOperator, to_dense +from linear_operator.operators import DenseLinearOperator, to_dense from linear_operator.settings import linalg_dtypes from linear_operator.utils.errors import CachingError from linear_operator.utils.memoize import get_from_cache from ..utils.warnings import PerformanceWarning -from .base_test_case import BaseTestCase +from .linear_operator_core_test_case import CoreLinearOperatorTestCase -class RectangularLinearOperatorTestCase(BaseTestCase): +class RectangularLinearOperatorTestCase(CoreLinearOperatorTestCase): tolerances = { "matmul": {"rtol": 1e-3}, @@ -59,19 +59,25 @@ def _test_matmul(self, rhs): self.assertAllClose(to_dense(res), actual) def _test_rmatmul(self, lhs): + # Note. 
transpose_nonbatch is tested implicitly here because + # the base linear operator class defines + # def rmatmul(other): + # return self.mT.matmul(other.mT).mT linear_op = self.create_linear_op().detach().requires_grad_(True) linear_op_copy = torch.clone(linear_op).detach().requires_grad_(True) evaluated = self.evaluate_linear_op(linear_op_copy) # Test operator res = lhs @ linear_op + res_evaluated = to_dense(res) actual = lhs @ evaluated - self.assertAllClose(res, actual) + self.assertAllClose(res_evaluated, actual) # Test __torch_function__ res = torch.matmul(lhs, linear_op) + res_evaluated = to_dense(res) actual = torch.matmul(lhs, evaluated) - self.assertAllClose(res, actual) + self.assertAllClose(res_evaluated, actual) grad = torch.randn_like(res) res.backward(gradient=grad) @@ -81,107 +87,19 @@ def _test_rmatmul(self, lhs): self.assertAllClose(arg.grad, arg_copy.grad, **self.tolerances["matmul"]) def test_add(self): - linear_op = self.create_linear_op() - evaluated = self.evaluate_linear_op(linear_op) - - rhs = torch.randn(linear_op.shape) - # Test operator functionality - a = (linear_op + rhs).to_dense() - b = evaluated + rhs - self.assertAllClose(a, b) - self.assertAllClose((linear_op + rhs).to_dense(), evaluated + rhs) - self.assertAllClose((rhs + linear_op).to_dense(), evaluated + rhs) - # Test __torch_function__ functionality - self.assertAllClose(torch.add(linear_op, rhs).to_dense(), evaluated + rhs) - self.assertAllClose(torch.add(rhs, linear_op).to_dense(), evaluated + rhs) - - rhs = torch.randn(linear_op.matrix_shape) - self.assertAllClose((linear_op + rhs).to_dense(), evaluated + rhs) + linear_op, evaluated = super().test_add() + # Test a batch of 2 rhs = torch.randn(2, *linear_op.shape) self.assertAllClose((linear_op + rhs).to_dense(), evaluated + rhs) - self.assertAllClose((linear_op + linear_op).to_dense(), evaluated * 2) - - def test_matmul_vec(self): - linear_op = self.create_linear_op() - - # We skip this test if we're dealing with batch LinearOperators - # They shouldn't multiply by a vec - if linear_op.ndimension() > 2: - return - - rhs = torch.randn(linear_op.size(-1)) - return self._test_matmul(rhs) - - def test_constant_mul(self): - linear_op = self.create_linear_op() - evaluated = self.evaluate_linear_op(linear_op) - - # Test operator functionality - self.assertAllClose((linear_op * 5.0).to_dense(), evaluated * 5.0) - self.assertAllClose((linear_op * torch.tensor(5.0)).to_dense(), evaluated * 5.0) - self.assertAllClose((5.0 * linear_op).to_dense(), evaluated * 5.0) - self.assertAllClose((torch.tensor(5.0) * linear_op).to_dense(), evaluated * 5.0) - - # Test __torch_function__ functionality - self.assertAllClose(torch.mul(linear_op, torch.tensor(5.0)).to_dense(), evaluated * 5.0) - self.assertAllClose(torch.mul(torch.tensor(5.0), linear_op).to_dense(), evaluated * 5.0) - - def test_constant_mul_neg(self): - linear_op = self.create_linear_op() - evaluated = self.evaluate_linear_op(linear_op) - self.assertAllClose((linear_op * -5.0).to_dense(), evaluated * -5.0) - - def test_constant_div(self): - linear_op = self.create_linear_op() - evaluated = self.evaluate_linear_op(linear_op) - - # Test operator functionality - self.assertAllClose((linear_op / 5.0).to_dense(), evaluated / 5.0) - self.assertAllClose((linear_op / torch.tensor(5.0)).to_dense(), evaluated / 5.0) - - # Test __torch_function__ functionality - self.assertAllClose(torch.div(linear_op, torch.tensor(5.0)).to_dense(), evaluated / 5.0) - - def test_to_dense(self): - linear_op = self.create_linear_op() - 
evaluated = self.evaluate_linear_op(linear_op) - self.assertAllClose(linear_op.to_dense(), evaluated) - def test_getitem(self): + super().test_getitem() linear_op = self.create_linear_op() evaluated = self.evaluate_linear_op(linear_op) - # Non-batch case - if linear_op.ndimension() == 2: - res = linear_op[1] - actual = evaluated[1] - self.assertAllClose(res, actual) - res = linear_op[0:2].to_dense() - actual = evaluated[0:2] - self.assertAllClose(res, actual) - res = linear_op[:, 0:2].to_dense() - actual = evaluated[:, 0:2] - self.assertAllClose(res, actual) - res = linear_op[0:2, :].to_dense() - actual = evaluated[0:2, :] - self.assertAllClose(res, actual) - res = linear_op[..., 0:2].to_dense() - actual = evaluated[..., 0:2] - self.assertAllClose(res, actual) - res = linear_op[0:2, ...].to_dense() - actual = evaluated[0:2, ...] - self.assertAllClose(res, actual) - res = linear_op[..., 0:2, 2] - actual = evaluated[..., 0:2, 2] - self.assertAllClose(res, actual) - res = linear_op[0:2, ..., 2] - actual = evaluated[0:2, ..., 2] - self.assertAllClose(res, actual) - # Batch case - else: + if linear_op.ndimension() != 2: res = linear_op[1].to_dense() actual = evaluated[1] self.assertAllClose(res, actual) @@ -212,32 +130,12 @@ def test_getitem(self): self.assertAllClose(res, actual) def test_getitem_tensor_index(self): + super().test_getitem_tensor_index() linear_op = self.create_linear_op() evaluated = self.evaluate_linear_op(linear_op) - # Non-batch case - if linear_op.ndimension() == 2: - index = (torch.tensor([0, 0, 1, 2]), torch.tensor([0, 1, 0, 2])) - res, actual = linear_op[index], evaluated[index] - self.assertAllClose(res, actual) - index = (torch.tensor([0, 0, 1, 2]), slice(None, None, None)) - res, actual = linear_operator.to_dense(linear_op[index]), evaluated[index] - self.assertAllClose(res, actual) - index = (slice(None, None, None), torch.tensor([0, 0, 1, 2])) - res, actual = linear_operator.to_dense(linear_op[index]), evaluated[index] - self.assertAllClose(res, actual) - index = (torch.tensor([0, 0, 1, 2]), Ellipsis) - res, actual = linear_operator.to_dense(linear_op[index]), evaluated[index] - self.assertAllClose(res, actual) - index = (Ellipsis, torch.tensor([0, 0, 1, 2])) - res, actual = linear_operator.to_dense(linear_op[index]), evaluated[index] - self.assertAllClose(res, actual) - index = (Ellipsis, torch.tensor([0, 0, 1, 2]), torch.tensor([0, 1, 0, 2])) - res, actual = linear_op[index], evaluated[index] - self.assertAllClose(res, actual) - # Batch case - else: + if linear_op.ndimension() != 2: for batch_index in product( [torch.tensor([0, 1, 1, 0]), slice(None, None, None)], repeat=(linear_op.dim() - 2), @@ -284,27 +182,12 @@ def test_getitem_tensor_index(self): self.assertAllClose(res, actual) def test_getitem_broadcasted_tensor_index(self): + super().test_getitem_broadcasted_tensor_index() linear_op = self.create_linear_op() evaluated = self.evaluate_linear_op(linear_op) - # Non-batch case - if linear_op.ndimension() == 2: - index = ( - torch.tensor([0, 0, 1, 2]).unsqueeze(-1), - torch.tensor([0, 1, 0, 2]).unsqueeze(-2), - ) - res, actual = linear_op[index], evaluated[index] - self.assertAllClose(res, actual) - index = ( - Ellipsis, - torch.tensor([0, 0, 1, 2]).unsqueeze(-2), - torch.tensor([0, 1, 0, 2]).unsqueeze(-1), - ) - res, actual = linear_op[index], evaluated[index] - self.assertAllClose(res, actual) - # Batch case - else: + if linear_op.ndimension() != 2: for batch_index in product( [torch.tensor([0, 1, 1]).view(-1, 1, 1), slice(None, None, None)], 
repeat=(linear_op.dim() - 2), @@ -359,49 +242,6 @@ def test_getitem_broadcasted_tensor_index(self): ) self.assertAllClose(res, actual) - def test_permute(self): - linear_op = self.create_linear_op() - if linear_op.dim() >= 4: - evaluated = self.evaluate_linear_op(linear_op) - dims = torch.randperm(linear_op.dim() - 2).tolist() - - # Call using __torch_function__ - res = torch.permute(linear_op, (*dims, -2, -1)).to_dense() - actual = torch.permute(evaluated, (*dims, -2, -1)) - self.assertAllClose(res, actual) - - # Call using method - res = linear_op.permute(*dims, -2, -1).to_dense() - actual = torch.permute(evaluated, (*dims, -2, -1)) - self.assertAllClose(res, actual) - - def test_rmatmul_vec(self): - linear_op = self.create_linear_op() - - # We skip this test if we're dealing with batch LinearOperators - # They shouldn't multiply by a vec - if linear_op.ndimension() > 2: - return - - lhs = torch.randn(linear_op.size(-2)) - return self._test_rmatmul(lhs) - - def test_matmul_matrix(self): - linear_op = self.create_linear_op() - rhs = torch.randn(*linear_op.batch_shape, linear_op.size(-1), 4) - return self._test_matmul(rhs) - - def test_rmatmul_matrix(self): - linear_op = self.create_linear_op() - lhs = torch.randn(*linear_op.batch_shape, 4, linear_op.size(-2)) - return self._test_rmatmul(lhs) - - def test_matmul_diag_matrix(self): - linear_op = self.create_linear_op() - diag = torch.rand(*linear_op.batch_shape, linear_op.size(-1)) - rhs = DiagLinearOperator(diag) - return self._test_matmul(rhs) - def test_matmul_matrix_broadcast(self): linear_op = self.create_linear_op() @@ -440,37 +280,6 @@ def test_rmatmul_matrix_broadcast(self): lhs = torch.randn(*batch_shape, 4, linear_op.size(-2)) self._test_rmatmul(lhs) - def test_rsub(self): - linear_op = self.create_linear_op() - evaluated = self.evaluate_linear_op(linear_op) - - rhs = torch.randn(linear_op.shape) - # Test operator functionality - self.assertAllClose((rhs - linear_op).to_dense(), rhs - evaluated) - # Test __torch_function__ functionality - self.assertAllClose(torch.sub(rhs, linear_op).to_dense(), rhs - evaluated) - - def test_sub(self): - linear_op = self.create_linear_op() - evaluated = self.evaluate_linear_op(linear_op) - - rhs = torch.randn(linear_op.shape) - # Test operator functionality - self.assertAllClose((linear_op - rhs).to_dense(), evaluated - rhs) - # Test __torch_function__ functionality - self.assertAllClose(torch.sub(linear_op, rhs).to_dense(), evaluated - rhs) - - def test_sum(self): - linear_op = self.create_linear_op() - evaluated = self.evaluate_linear_op(linear_op) - - self.assertAllClose(torch.sum(linear_op, -1), torch.sum(evaluated, -1)) - self.assertAllClose(torch.sum(linear_op, -2), torch.sum(evaluated, -2)) - if linear_op.ndimension() > 2: - self.assertAllClose(torch.sum(linear_op, -3).to_dense(), torch.sum(evaluated, -3)) - if linear_op.ndimension() > 3: - self.assertAllClose(torch.sum(linear_op, -4).to_dense(), torch.sum(evaluated, -4)) - def test_squeeze_unsqueeze(self): linear_operator = self.create_linear_op() evaluated = self.evaluate_linear_op(linear_operator) From 67b8fe9e50631c3382174b86b45e9b8d9e67ef12 Mon Sep 17 00:00:00 2001 From: Corwin Joy Date: Tue, 18 Jul 2023 17:04:22 -0700 Subject: [PATCH 17/23] Incorporate review suggestions from Geoff Pleiss. 
--- .../operators/block_matrix_linear_operator.py | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/linear_operator/operators/block_matrix_linear_operator.py b/linear_operator/operators/block_matrix_linear_operator.py index eb9d1d23..ad0518e8 100644 --- a/linear_operator/operators/block_matrix_linear_operator.py +++ b/linear_operator/operators/block_matrix_linear_operator.py @@ -4,6 +4,7 @@ from jaxtyping import Float from torch import Tensor +from .. import settings from ._linear_operator import IndexType, LinearOperator from .dense_linear_operator import DenseLinearOperator from .zero_linear_operator import ZeroLinearOperator @@ -23,9 +24,12 @@ class BlockMatrixLinearOperator(LinearOperator): """ def __init__(self, linear_operators: List[List[LinearOperator]]) -> None: - assert isinstance(linear_operators, list), f"{self.__class__.__name__} expects a nested list of LinearOperators" - assert len(linear_operators) > 0, "must have non-empty list" - assert len(linear_operators[0]) == len(linear_operators), "must be square over block dimensions" + if settings.debug.on(): + assert hasattr( + linear_operators, "__iter__" + ), f"{self.__class__.__name__} expects a nested list (or iterable) of LinearOperators" + assert len(linear_operators) > 0, "must have non-empty list" + assert len(linear_operators[0]) == len(linear_operators), "must be square over block dimensions" super().__init__(linear_operators) @@ -91,6 +95,10 @@ def _matmul( # Failover implementation. Convert to dense and multiply matricies A = self.to_dense() B = rhs.to_dense() + + # Batch logic is not supported for now + assert B.ndim <= 2 + res = A @ B return res From 75b565a10b1fef4fc4573dea93c00315224ce00b Mon Sep 17 00:00:00 2001 From: Geoff Pleiss <824157+gpleiss@users.noreply.github.com> Date: Thu, 27 Jul 2023 10:44:38 -0700 Subject: [PATCH 18/23] Add comment explaining matmul override. --- linear_operator/operators/block_matrix_linear_operator.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/linear_operator/operators/block_matrix_linear_operator.py b/linear_operator/operators/block_matrix_linear_operator.py index ad0518e8..52800c98 100644 --- a/linear_operator/operators/block_matrix_linear_operator.py +++ b/linear_operator/operators/block_matrix_linear_operator.py @@ -106,6 +106,8 @@ def matmul( self: Float[LinearOperator, "*batch M N"], other: Union[Float[Tensor, "*batch2 N P"], Float[Tensor, "*batch2 N"], Float[LinearOperator, "*batch2 N P"]], ) -> Union[Float[Tensor, "... M P"], Float[Tensor, "... M"], Float[LinearOperator, "... M P"]]: + # The base method wants to perform a matmul via broadcasting and a + # representation tree which this operator doesn't support. 
return self._matmul(other) def to_dense(self: Float[LinearOperator, "*batch M N"]) -> Float[Tensor, "*batch M N"]: From ff0b6a2b3965b20203ff84810749cbb63ddb03a5 Mon Sep 17 00:00:00 2001 From: Geoff Pleiss Date: Fri, 2 Jun 2023 20:45:31 +0000 Subject: [PATCH 19/23] Add jaxtyping requirement for conda --- .conda/meta.yaml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/.conda/meta.yaml b/.conda/meta.yaml index 40a5696e..068a519f 100644 --- a/.conda/meta.yaml +++ b/.conda/meta.yaml @@ -18,6 +18,8 @@ requirements: run: - pytorch>=1.11 - scipy + - jaxtyping>=0.2.9 + - typeguard~=2.13.3 test: imports: From 58e868614ab03cef48f4b9f9190f6a10f789e1e9 Mon Sep 17 00:00:00 2001 From: Geoff Pleiss <824157+gpleiss@users.noreply.github.com> Date: Thu, 27 Jul 2023 17:55:58 +0000 Subject: [PATCH 20/23] Fix linter --- linear_operator/operators/block_matrix_linear_operator.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/linear_operator/operators/block_matrix_linear_operator.py b/linear_operator/operators/block_matrix_linear_operator.py index 52800c98..2ca60ed0 100644 --- a/linear_operator/operators/block_matrix_linear_operator.py +++ b/linear_operator/operators/block_matrix_linear_operator.py @@ -106,7 +106,7 @@ def matmul( self: Float[LinearOperator, "*batch M N"], other: Union[Float[Tensor, "*batch2 N P"], Float[Tensor, "*batch2 N"], Float[LinearOperator, "*batch2 N P"]], ) -> Union[Float[Tensor, "... M P"], Float[Tensor, "... M"], Float[LinearOperator, "... M P"]]: - # The base method wants to perform a matmul via broadcasting and a + # The base method wants to perform a matmul via broadcasting and a # representation tree which this operator doesn't support. return self._matmul(other) From 7f803e3dcec5746f28b8942fee25c2f7e53ae5fb Mon Sep 17 00:00:00 2001 From: Geoff Pleiss <824157+gpleiss@users.noreply.github.com> Date: Thu, 27 Jul 2023 18:16:47 +0000 Subject: [PATCH 21/23] Hopefully fix weird CI errors --- linear_operator/operators/block_matrix_linear_operator.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/linear_operator/operators/block_matrix_linear_operator.py b/linear_operator/operators/block_matrix_linear_operator.py index 2ca60ed0..b6cd1cd1 100644 --- a/linear_operator/operators/block_matrix_linear_operator.py +++ b/linear_operator/operators/block_matrix_linear_operator.py @@ -5,7 +5,7 @@ from torch import Tensor from .. import settings -from ._linear_operator import IndexType, LinearOperator +from ._linear_operator import IndexType, LinearOperator, to_dense from .dense_linear_operator import DenseLinearOperator from .zero_linear_operator import ZeroLinearOperator @@ -94,7 +94,7 @@ def _matmul( # Failover implementation. 
Convert to dense and multiply matricies A = self.to_dense() - B = rhs.to_dense() + B = to_dense(rhs) # Batch logic is not supported for now assert B.ndim <= 2 From c7d094dd43b689879a3961e89da7925fd77a25f9 Mon Sep 17 00:00:00 2001 From: Geoff Pleiss <824157+gpleiss@users.noreply.github.com> Date: Thu, 27 Jul 2023 19:16:04 +0000 Subject: [PATCH 22/23] Refactor BlockMatrixLO._matmul to better adhere to type signatures --- .../operators/block_matrix_linear_operator.py | 48 ++++++++++++------- .../test_block_matrix_linear_operator.py | 6 +-- 2 files changed, 34 insertions(+), 20 deletions(-) diff --git a/linear_operator/operators/block_matrix_linear_operator.py b/linear_operator/operators/block_matrix_linear_operator.py index b6cd1cd1..20e33b3e 100644 --- a/linear_operator/operators/block_matrix_linear_operator.py +++ b/linear_operator/operators/block_matrix_linear_operator.py @@ -60,25 +60,30 @@ def create_square_ops_output(T: int) -> List[List[LinearOperator]]: ops.append(tmp) return ops + def _matmul_two_block_matrix_linear_operators( + self: "BlockMatrixLinearOperator", + other: "BlockMatrixLinearOperator", + ) -> "BlockMatrixLinearOperator": + assert self.num_tasks == other.num_tasks + assert self.block_cols == other.block_rows + + T = self.num_tasks + output = BlockMatrixLinearOperator.create_square_ops_output(T) + for i in range(T): + for j in range(T): + out_ij = self.linear_operators[i][0] @ other.linear_operators[0][j] + for k in range(1, T): + out_ij += self.linear_operators[i][k] @ other.linear_operators[k][j] + output[i][j] = out_ij + return self.__class__(output) + def _matmul( self: Float[LinearOperator, "*batch M N"], rhs: Union[Float[torch.Tensor, "*batch2 N C"], Float[torch.Tensor, "*batch2 N"]], ) -> Union[Float[torch.Tensor, "... M C"], Float[torch.Tensor, "... M"]]: - T = self.num_tasks - # A is block [N * T1, M * T2] and B is block [O * S1, P * S2]. If A and B have conformal block counts - # ie T2==S1 as well as M==O then use the blockwise algorithm. Else use to_dense() - if isinstance(rhs, self.__class__) and self.num_tasks == rhs.num_tasks and self.block_cols == rhs.block_rows: - output = BlockMatrixLinearOperator.create_square_ops_output(T) - for i in range(T): - for j in range(T): - out_ij = self.linear_operators[i][0] @ rhs.linear_operators[0][j] - for k in range(1, T): - out_ij += self.linear_operators[i][k] @ rhs.linear_operators[k][j] - output[i][j] = out_ij - return self.__class__(output) - elif isinstance(rhs, Tensor) and rhs.ndim == 2: + if isinstance(rhs, Tensor) and rhs.ndim == 2: # Check both matrix dims divisible by T, # reshape to (T, T, ), call block multiplication if rhs.size(0) % T == 0 and rhs.size(1) % T == 0: @@ -90,15 +95,14 @@ def _matmul( rhs_blocks_raw = rhs.reshape(T, O_T, T, P_T) rhs_blocks = rhs_blocks_raw.permute(0, 2, 1, 3) rhs_op = BlockMatrixLinearOperator.from_tensor(rhs_blocks, T) - return self._matmul(rhs_op) + return self._matmul_two_block_matrix_linear_operators(rhs_op).to_dense() # Failover implementation. Convert to dense and multiply matricies + # Batch logic is not supported for now + assert rhs.dim() <= 2 A = self.to_dense() B = to_dense(rhs) - # Batch logic is not supported for now - assert B.ndim <= 2 - res = A @ B return res @@ -106,6 +110,16 @@ def matmul( self: Float[LinearOperator, "*batch M N"], other: Union[Float[Tensor, "*batch2 N P"], Float[Tensor, "*batch2 N"], Float[LinearOperator, "*batch2 N P"]], ) -> Union[Float[Tensor, "... M P"], Float[Tensor, "... M"], Float[LinearOperator, "... 
M P"]]: + # A is block [N * T1, M * T2] and B is block [O * S1, P * S2]. If A and B have conformal block counts + # ie T2==S1 as well as M==O then use the blockwise algorithm. Else use to_dense() + if isinstance(other, self.__class__): + if self.num_tasks == other.num_tasks and self.block_cols == other.block_rows: + return self._matmul_two_block_matrix_linear_operators(other) + elif isinstance(other, LinearOperator): + from .matmul_linear_operator import MatmulLinearOperator + + return MatmulLinearOperator(self, other) + # The base method wants to perform a matmul via broadcasting and a # representation tree which this operator doesn't support. return self._matmul(other) diff --git a/test/operators/test_block_matrix_linear_operator.py b/test/operators/test_block_matrix_linear_operator.py index ca8cdc16..9a7da202 100644 --- a/test/operators/test_block_matrix_linear_operator.py +++ b/test/operators/test_block_matrix_linear_operator.py @@ -28,7 +28,7 @@ def test_multiply(self): A_blo = BlockMatrixLinearOperator.from_tensor(A, T) B_blo = BlockMatrixLinearOperator.from_tensor(B, T) - res_AB = A_blo._matmul(B_blo) + res_AB = A_blo.matmul(B_blo) res_dense_AB = res_AB.to_dense() A_dense = A.permute(0, 2, 1, 3).reshape(T * N, T * M) @@ -43,7 +43,7 @@ def test_multiply(self): self.assertAllClose(A, A_blocks_est) # Check Tensor multiplication - res_tensor_AB = A_blo._matmul(B_dense) + res_tensor_AB = A_blo.matmul(B_dense) res_tensor_dense_AB = res_tensor_AB.to_dense() self.assertAllClose(res_dense_AB, res_tensor_dense_AB) @@ -73,7 +73,7 @@ def test_sparse_multiply(self): # A_blo will contain dense operators along the diagonal + Zero operators off diagonal A_blo = BlockMatrixLinearOperator.from_tensor(A, T) B_blo = BlockMatrixLinearOperator.from_tensor(B, T) - res_AB = A_blo._matmul(B_blo) + res_AB = A_blo.matmul(B_blo) res_dense_AB = res_AB.to_dense() expected = A_dense @ B_dense From 8d29dd79dc47eb88946b183a9894e339c5429598 Mon Sep 17 00:00:00 2001 From: Geoff Pleiss <824157+gpleiss@users.noreply.github.com> Date: Thu, 27 Jul 2023 19:17:38 +0000 Subject: [PATCH 23/23] BlockMatrixLO takes in a flattened represetation --- .../operators/block_matrix_linear_operator.py | 66 +++++++------------ 1 file changed, 22 insertions(+), 44 deletions(-) diff --git a/linear_operator/operators/block_matrix_linear_operator.py b/linear_operator/operators/block_matrix_linear_operator.py index 20e33b3e..9aba135a 100644 --- a/linear_operator/operators/block_matrix_linear_operator.py +++ b/linear_operator/operators/block_matrix_linear_operator.py @@ -1,4 +1,5 @@ -from typing import List, Optional, Tuple, Union +import math +from typing import List, Optional, Union import torch from jaxtyping import Float @@ -20,34 +21,24 @@ class BlockMatrixLinearOperator(LinearOperator): represented as T^2 lazy tensors of the same shape. Implement matrix multiplication between block matrices as the appropriate linear operators on the blocks. - :param linear_operators: A TxT nested list of linear operators representing a 2-D matrix + :param linear_operators: A T^2 (flattened) list of linear operators representing a 2-D TxT block matrix. + The list of linear operators should be flattened into a concatenation of block-rowsa. 
""" - def __init__(self, linear_operators: List[List[LinearOperator]]) -> None: + def __init__(self, *flattened_linear_operators: LinearOperator) -> None: + self.num_tasks = int(math.sqrt(len(flattened_linear_operators))) + if settings.debug.on(): - assert hasattr( - linear_operators, "__iter__" - ), f"{self.__class__.__name__} expects a nested list (or iterable) of LinearOperators" - assert len(linear_operators) > 0, "must have non-empty list" - assert len(linear_operators[0]) == len(linear_operators), "must be square over block dimensions" + assert len(flattened_linear_operators) > 0, "must have non-empty list" + assert self.num_tasks**2 == len(flattened_linear_operators) - super().__init__(linear_operators) + super().__init__(*flattened_linear_operators) - self.linear_operators = linear_operators - self.num_tasks = len(self.linear_operators) - self.block_rows = linear_operators[0][0].shape[0] - self.block_cols = linear_operators[0][0].shape[1] - - # Check that provided operators all have the same shape - T = self.num_tasks - for i in range(T): - for j in range(T): - assert ( - linear_operators[i][j].shape[0] == self.block_rows - ), "the number of rows much match for all linear operators" - assert ( - linear_operators[i][j].shape[1] == self.block_cols - ), "the number of columns much match for all linear operators" + self.linear_operators = tuple( + flattened_linear_operators[i * self.num_tasks : (i + 1) * self.num_tasks] for i in range(self.num_tasks) + ) + self.block_rows = self.linear_operators[0][0].shape[0] + self.block_cols = self.linear_operators[0][0].shape[1] @staticmethod def create_square_ops_output(T: int) -> List[List[LinearOperator]]: @@ -68,14 +59,14 @@ def _matmul_two_block_matrix_linear_operators( assert self.block_cols == other.block_rows T = self.num_tasks - output = BlockMatrixLinearOperator.create_square_ops_output(T) + output = [] for i in range(T): for j in range(T): out_ij = self.linear_operators[i][0] @ other.linear_operators[0][j] for k in range(1, T): out_ij += self.linear_operators[i][k] @ other.linear_operators[k][j] - output[i][j] = out_ij - return self.__class__(output) + output.append(out_ij) + return self.__class__(*output) def _matmul( self: Float[LinearOperator, "*batch M N"], @@ -145,16 +136,6 @@ def dtype(self) -> Optional[torch.dtype]: def device(self) -> Optional[torch.device]: return self.linear_operators[0][0].device - def representation(self) -> Tuple[torch.Tensor, ...]: - """ - Returns the Tensors that are used to define the LinearOperator - """ - representation = [] - for op_row in self.linear_operators: - for op in op_row: - representation += tuple(op.representation()) - return tuple(representation) - def _diag(self: Float[LinearOperator, "... M N"]) -> Float[torch.Tensor, "... N"]: out = [] for i in range(self.num_tasks): @@ -166,11 +147,9 @@ def _diag(self: Float[LinearOperator, "... M N"]) -> Float[torch.Tensor, "... 
N" def _transpose_nonbatch(self: Float[LinearOperator, "*batch M N"]) -> Float[LinearOperator, "*batch N M"]: out = [] for i in range(self.num_tasks): - rows = [] for j in range(self.num_tasks): - rows.append(self.linear_operators[j][i].mT) - out.append(rows) - return BlockMatrixLinearOperator(out) + out.append(self.linear_operators[j][i].mT) + return BlockMatrixLinearOperator(*out) def _getitem(self, row_index: IndexType, col_index: IndexType, *batch_indices: IndexType) -> LinearOperator: # Perform the __getitem__ @@ -186,7 +165,6 @@ def tensor_to_linear_op(t: Tensor) -> LinearOperator: return ZeroLinearOperator(*t.size(), dtype=t.dtype, device=t.device) linear_ops = [ - [tensor_to_linear_op(t[0]) for t in list(torch.tensor_split(tensor[i], num_tasks))] - for i in range(num_tasks) + tensor_to_linear_op(t[0]) for i in range(num_tasks) for t in list(torch.tensor_split(tensor[i], num_tasks)) ] - return cls(linear_ops) + return cls(*linear_ops)
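
For reference, a short usage sketch of the operator as it stands at the end of this series,
mirroring the shapes and checks in test_block_matrix_linear_operator.py (illustrative only;
variable names are arbitrary):

    import torch

    from linear_operator.operators import BlockMatrixLinearOperator

    T, N, M, K = 2, 4, 3, 5
    A = torch.randn(T, T, N, M)  # a T x T grid of N x M blocks
    B = torch.randn(T, T, M, K)  # a T x T grid of M x K blocks

    A_blo = BlockMatrixLinearOperator.from_tensor(A, T)
    B_blo = BlockMatrixLinearOperator.from_tensor(B, T)

    # Blockwise product of two conformal block operators
    res = A_blo.matmul(B_blo).to_dense()

    # Same result as assembling the dense (T*N, T*M) and (T*M, T*K) matrices first
    A_dense = A.permute(0, 2, 1, 3).reshape(T * N, T * M)
    B_dense = B.permute(0, 2, 1, 3).reshape(T * M, T * K)
    assert torch.allclose(res, A_dense @ B_dense, atol=1e-4)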