-
Notifications
You must be signed in to change notification settings - Fork 114
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Implement linear algebra functions in PyTorch #922
base: main
Are you sure you want to change the base?
Conversation
Codecov ReportAttention: Patch coverage is
Additional details and impacted files@@ Coverage Diff @@
## main #922 +/- ##
==========================================
- Coverage 81.74% 81.73% -0.02%
==========================================
Files 183 184 +1
Lines 47740 47791 +51
Branches 11616 11623 +7
==========================================
+ Hits 39027 39061 +34
- Misses 6520 6534 +14
- Partials 2193 2196 +3
|
05118b5
to
2fb9f0e
Compare
|
||
def solve(a, b): | ||
if lower: | ||
return torch.linalg.solve(torch.tril(a), b) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't think we need to do this. At worst it will introduce silent bugs in user code (because solve
doesn't do any type checking w.r.t the keywords and args), at best it introduces an extra tril
Op
for nothing.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
torch.solve
doesn't have any keywords, so I suggest just directly returning the base case.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
torch.solve
is deprecated.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I mean torch.linalg.solve
, which also has none of the scipy solver keywords (most importantly assume_a
)
trans = op.trans | ||
|
||
def solve_triangular(A, b): | ||
A_p = A |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can use if ... elif .. else
here
|
||
def solve_triangular(A, b): | ||
A_p = A | ||
if trans == 1 or trans == "T": |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
if trans == 1 or trans == "T": | |
if trans in [1, "T"]: |
if trans == 1 or trans == "T": | ||
A_p = A.T | ||
|
||
if trans == 2 or trans == "C": |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
if trans == 2 or trans == "C": | |
if trans in [2, "C"]: |
A_p = A.T | ||
|
||
if trans == 2 or trans == "C": | ||
A_p = A.H |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Pretty sure we don't have .H
, need to do .conj().T
@twaclaw can you rebase the PR to solve the conflicts. And if you addressed @jessegrabowski feel free to mark the comments as resolved for us to see if there's anything left to be addressed. Let us know if you are not available. And thanks in advance! |
Will do shortly, I was indeed not available for the last month or so but I am back now. |
- BlockDiagonal - Cholesky - Eigvalsh - Solve - SolveTriangular
Added support for a b param of ndim = 1.
- Removed lower param to solve - Refactored tests
Co-authored-by: Jesse Grabowski <[email protected]>
This reverts commit 0dc65e2.
76e1471
to
4067a87
Compare
Description
Implemented:
Related Issue
Checklist
Type of change