-
Notifications
You must be signed in to change notification settings - Fork 30
Block Matrix Linear Operator #67
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
corwinjoy
wants to merge
26
commits into
cornellius-gp:main
Choose a base branch
from
corwinjoy:block_tensor_lo
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from all commits
Commits
Show all changes
26 commits
Select commit
Hold shift + click to select a range
e2e32f1
Rebase vs. main
15effdf
Fix block tensor type signatures
4055a7f
Get simple test running for BlockTensor
2cf2db0
Add simple property implementations
61d9b8a
Upgrade linear operator to add block / sparse test
bbc12ea
Add and document core test cases
bfc843a
Cleanup dead comments
516b369
Update linear_operator/operators/block_tensor_linear_operator.py
corwinjoy 9d12d7c
Update linear_operator/operators/block_tensor_linear_operator.py
corwinjoy 56809b2
Improve construction tests and types
2f9b4b7
Rename class to MatrixLinearOperator
7cbbc54
Add parts omitted from base test case. Show them as commented out to …
32e9a52
Rename class to BlockMatrixLinearOperator
30ad0ed
Fix type signature
889ce0f
Improve comments
d15f368
Merge branch 'main' into block_tensor_lo
corwinjoy c557aa3
Refactor linear_operator_test_case.py into a set of core tests and mo…
d2cb1cc
Merge remote-tracking branch 'origin/block_tensor_lo' into block_tens…
67b8fe9
Incorporate review suggestions from Geoff Pleiss.
75b565a
Add comment explaining matmul override.
gpleiss ff0b6a2
Add jaxtyping requirement for conda
gpleiss ffc3116
Merge branch 'main' into block_tensor_lo
gpleiss 58e8686
Fix linter
gpleiss 7f803e3
Hopefully fix weird CI errors
gpleiss c7d094d
Refactor BlockMatrixLO._matmul to better adhere to type signatures
gpleiss 8d29dd7
BlockMatrixLO takes in a flattened represetation
gpleiss File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -18,6 +18,8 @@ requirements: | |
run: | ||
- pytorch>=1.11 | ||
- scipy | ||
- jaxtyping>=0.2.9 | ||
- typeguard~=2.13.3 | ||
|
||
test: | ||
imports: | ||
|
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
170 changes: 170 additions & 0 deletions
170
linear_operator/operators/block_matrix_linear_operator.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,170 @@ | ||
import math | ||
from typing import List, Optional, Union | ||
|
||
import torch | ||
from jaxtyping import Float | ||
from torch import Tensor | ||
|
||
from .. import settings | ||
from ._linear_operator import IndexType, LinearOperator, to_dense | ||
from .dense_linear_operator import DenseLinearOperator | ||
from .zero_linear_operator import ZeroLinearOperator | ||
|
||
|
||
class BlockMatrixLinearOperator(LinearOperator): | ||
""" | ||
A TxT block matrix of LinearOperators. | ||
|
||
Idea. Represent [TN, TM] tensors by TxT blocks of NxM lazy tensors. | ||
|
||
Implementation. A block linear operator class that can keep track of the [T, T] block structure, | ||
represented as T^2 lazy tensors of the same shape. Implement matrix multiplication between block matrices as | ||
the appropriate linear operators on the blocks. | ||
|
||
:param linear_operators: A T^2 (flattened) list of linear operators representing a 2-D TxT block matrix. | ||
The list of linear operators should be flattened into a concatenation of block-rowsa. | ||
""" | ||
|
||
def __init__(self, *flattened_linear_operators: LinearOperator) -> None: | ||
self.num_tasks = int(math.sqrt(len(flattened_linear_operators))) | ||
|
||
if settings.debug.on(): | ||
assert len(flattened_linear_operators) > 0, "must have non-empty list" | ||
assert self.num_tasks**2 == len(flattened_linear_operators) | ||
|
||
super().__init__(*flattened_linear_operators) | ||
|
||
self.linear_operators = tuple( | ||
flattened_linear_operators[i * self.num_tasks : (i + 1) * self.num_tasks] for i in range(self.num_tasks) | ||
) | ||
self.block_rows = self.linear_operators[0][0].shape[0] | ||
self.block_cols = self.linear_operators[0][0].shape[1] | ||
|
||
@staticmethod | ||
def create_square_ops_output(T: int) -> List[List[LinearOperator]]: | ||
"""Return an empty (square) list of operators of shape TxT""" | ||
ops = [] | ||
for i in range(T): | ||
tmp = [] | ||
for j in range(T): | ||
tmp.append([]) | ||
ops.append(tmp) | ||
return ops | ||
|
||
def _matmul_two_block_matrix_linear_operators( | ||
self: "BlockMatrixLinearOperator", | ||
other: "BlockMatrixLinearOperator", | ||
) -> "BlockMatrixLinearOperator": | ||
assert self.num_tasks == other.num_tasks | ||
assert self.block_cols == other.block_rows | ||
|
||
T = self.num_tasks | ||
output = [] | ||
for i in range(T): | ||
for j in range(T): | ||
out_ij = self.linear_operators[i][0] @ other.linear_operators[0][j] | ||
for k in range(1, T): | ||
out_ij += self.linear_operators[i][k] @ other.linear_operators[k][j] | ||
output.append(out_ij) | ||
return self.__class__(*output) | ||
|
||
def _matmul( | ||
self: Float[LinearOperator, "*batch M N"], | ||
rhs: Union[Float[torch.Tensor, "*batch2 N C"], Float[torch.Tensor, "*batch2 N"]], | ||
) -> Union[Float[torch.Tensor, "... M C"], Float[torch.Tensor, "... M"]]: | ||
T = self.num_tasks | ||
|
||
if isinstance(rhs, Tensor) and rhs.ndim == 2: | ||
# Check both matrix dims divisible by T, | ||
# reshape to (T, T, ), call block multiplication | ||
if rhs.size(0) % T == 0 and rhs.size(1) % T == 0: | ||
# A is block [N * T, M * T] and B is a general tensor/operator of shape [O, P]. | ||
# If O and P are both divisible by T, | ||
# then interpret B as a [O//T * T, P//T * T] block matrix | ||
O_T = rhs.size(0) // T | ||
P_T = rhs.size(1) // T | ||
rhs_blocks_raw = rhs.reshape(T, O_T, T, P_T) | ||
rhs_blocks = rhs_blocks_raw.permute(0, 2, 1, 3) | ||
rhs_op = BlockMatrixLinearOperator.from_tensor(rhs_blocks, T) | ||
return self._matmul_two_block_matrix_linear_operators(rhs_op).to_dense() | ||
|
||
# Failover implementation. Convert to dense and multiply matricies | ||
# Batch logic is not supported for now | ||
assert rhs.dim() <= 2 | ||
A = self.to_dense() | ||
B = to_dense(rhs) | ||
|
||
res = A @ B | ||
return res | ||
|
||
def matmul( | ||
self: Float[LinearOperator, "*batch M N"], | ||
other: Union[Float[Tensor, "*batch2 N P"], Float[Tensor, "*batch2 N"], Float[LinearOperator, "*batch2 N P"]], | ||
) -> Union[Float[Tensor, "... M P"], Float[Tensor, "... M"], Float[LinearOperator, "... M P"]]: | ||
# A is block [N * T1, M * T2] and B is block [O * S1, P * S2]. If A and B have conformal block counts | ||
# ie T2==S1 as well as M==O then use the blockwise algorithm. Else use to_dense() | ||
if isinstance(other, self.__class__): | ||
if self.num_tasks == other.num_tasks and self.block_cols == other.block_rows: | ||
return self._matmul_two_block_matrix_linear_operators(other) | ||
elif isinstance(other, LinearOperator): | ||
from .matmul_linear_operator import MatmulLinearOperator | ||
|
||
return MatmulLinearOperator(self, other) | ||
|
||
# The base method wants to perform a matmul via broadcasting and a | ||
# representation tree which this operator doesn't support. | ||
return self._matmul(other) | ||
gpleiss marked this conversation as resolved.
Show resolved
Hide resolved
gpleiss marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
def to_dense(self: Float[LinearOperator, "*batch M N"]) -> Float[Tensor, "*batch M N"]: | ||
out = [] | ||
for i in range(self.num_tasks): | ||
rows = [] | ||
for j in range(self.num_tasks): | ||
rows.append(self.linear_operators[i][j].to_dense()) | ||
out.append(torch.concat(rows, axis=1)) | ||
return torch.concat(out, axis=0) | ||
|
||
def _size(self) -> torch.Size: | ||
sz = self.linear_operators[0][0].size() | ||
return torch.Size([self.num_tasks * sz[0], self.num_tasks * sz[1]]) | ||
|
||
@property | ||
def dtype(self) -> Optional[torch.dtype]: | ||
return self.linear_operators[0][0].dtype | ||
|
||
@property | ||
def device(self) -> Optional[torch.device]: | ||
return self.linear_operators[0][0].device | ||
|
||
def _diag(self: Float[LinearOperator, "... M N"]) -> Float[torch.Tensor, "... N"]: | ||
out = [] | ||
for i in range(self.num_tasks): | ||
# The underlying operators will test if they are square | ||
diagonal = self.linear_operators[i][i].diagonal() | ||
out.append(diagonal) | ||
return torch.concat(out, axis=1) | ||
|
||
def _transpose_nonbatch(self: Float[LinearOperator, "*batch M N"]) -> Float[LinearOperator, "*batch N M"]: | ||
out = [] | ||
for i in range(self.num_tasks): | ||
for j in range(self.num_tasks): | ||
out.append(self.linear_operators[j][i].mT) | ||
return BlockMatrixLinearOperator(*out) | ||
|
||
def _getitem(self, row_index: IndexType, col_index: IndexType, *batch_indices: IndexType) -> LinearOperator: | ||
# Perform the __getitem__ | ||
tsr = self.to_dense() | ||
res = tsr[(*batch_indices, row_index, col_index)] | ||
return DenseLinearOperator(res) | ||
|
||
@classmethod | ||
def from_tensor(cls, tensor: Tensor, num_tasks: int) -> "BlockMatrixLinearOperator": | ||
def tensor_to_linear_op(t: Tensor) -> LinearOperator: | ||
if torch.count_nonzero(t) > 0: | ||
return DenseLinearOperator(t) | ||
return ZeroLinearOperator(*t.size(), dtype=t.dtype, device=t.device) | ||
|
||
linear_ops = [ | ||
tensor_to_linear_op(t[0]) for i in range(num_tasks) for t in list(torch.tensor_split(tensor[i], num_tasks)) | ||
] | ||
return cls(*linear_ops) |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.