

Implement linear algebra functions in PyTorch #922

Open · wants to merge 6 commits into base: main

Conversation

@twaclaw (Contributor) commented Jul 11, 2024

Description

Implemented:

  • BlockDiagonal
  • Cholesky
  • Eigvalsh
  • Solve
  • SolveTriangular
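For reference, these Ops map onto the following PyTorch primitives (a minimal sketch of the underlying calls only; the actual dispatch wrappers live in pytensor/link/pytorch/dispatch/slinalg.py and are not reproduced here):

```python
import torch

A = torch.tensor([[4.0, 1.0],
                  [1.0, 3.0]])  # symmetric positive definite

L = torch.linalg.cholesky(A)           # Cholesky: A = L @ L.T
w = torch.linalg.eigvalsh(A)           # Eigvalsh: eigenvalues of a symmetric matrix
D = torch.block_diag(A, torch.eye(3))  # BlockDiagonal: 5x5 block-diagonal matrix
x = torch.linalg.solve(A, torch.tensor([1.0, 0.0]))  # Solve: A @ x = b
y = torch.linalg.solve_triangular(     # SolveTriangular: L @ y = b, L lower
    L, torch.tensor([[1.0], [0.0]]), upper=False
)
```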

Related Issue

Checklist

Type of change

  • New feature / enhancement
  • Bug fix
  • Documentation
  • Maintenance
  • Other (please specify):

codecov bot commented Jul 11, 2024

Codecov Report

Attention: Patch coverage is 66.66667% with 17 lines in your changes missing coverage. Please review.

Project coverage is 81.73%. Comparing base (117f80d) to head (4067a87).

Files with missing lines Patch % Lines
pytensor/link/pytorch/dispatch/slinalg.py 66.00% 16 Missing and 1 partial ⚠️
Additional details and impacted files


@@            Coverage Diff             @@
##             main     #922      +/-   ##
==========================================
- Coverage   81.74%   81.73%   -0.02%     
==========================================
  Files         183      184       +1     
  Lines       47740    47791      +51     
  Branches    11616    11623       +7     
==========================================
+ Hits        39027    39061      +34     
- Misses       6520     6534      +14     
- Partials     2193     2196       +3     
Files with missing lines Coverage Δ
pytensor/link/pytorch/dispatch/__init__.py 100.00% <100.00%> (ø)
pytensor/link/pytorch/dispatch/slinalg.py 66.00% <66.00%> (ø)

... and 2 files with indirect coverage changes

@twaclaw twaclaw force-pushed the implement_slinalg_ops_pytorch branch from 05118b5 to 2fb9f0e Compare July 13, 2024 17:16
@ricardoV94 ricardoV94 added enhancement New feature or request linalg Linear algebra torch PyTorch backend labels Jul 18, 2024

def solve(a, b):
    if lower:
        return torch.linalg.solve(torch.tril(a), b)
Member:

I don't think we need to do this. At worst it will introduce silent bugs in user code (because solve doesn't do any type checking w.r.t. the keywords and args); at best it introduces an extra tril Op for nothing.

Member:

torch.solve doesn't have any keywords, so I suggest just directly returning the base case.

Contributor Author:

torch.solve is deprecated.

Member:

I mean torch.linalg.solve, which also has none of the SciPy solver keywords (most importantly assume_a).
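Following this thread, the Solve dispatch can forward directly to torch.linalg.solve, with no tril masking and no keyword translation (a hedged sketch of the reviewers' suggestion, not necessarily the merged code; torch.linalg.solve has no assume_a or lower parameters to map):

```python
import torch

def solve(a, b):
    # torch.linalg.solve accepts no scipy-style keywords (assume_a, lower, ...),
    # so the base case is returned directly.
    return torch.linalg.solve(a, b)
```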

trans = op.trans

def solve_triangular(A, b):
    A_p = A
Member:

Can use if ... elif ... else here


def solve_triangular(A, b):
    A_p = A
    if trans == 1 or trans == "T":
Member:

Suggested change:
-    if trans == 1 or trans == "T":
+    if trans in [1, "T"]:

    if trans == 1 or trans == "T":
        A_p = A.T

    if trans == 2 or trans == "C":
Member:

Suggested change:
-    if trans == 2 or trans == "C":
+    if trans in [2, "C"]:

        A_p = A.T

    if trans == 2 or trans == "C":
        A_p = A.H
Member:

Pretty sure we don't have .H, need to do .conj().T
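Combining the suggestions in this thread (if/elif/else, membership tests, and .conj().T in place of .H) gives roughly the following sketch. The factory wrapper and the upper-flag flip for the transposed cases are illustrative assumptions for a self-contained example, not the merged code:

```python
import torch

def make_solve_triangular(trans, lower=True):
    # trans follows the scipy.linalg.solve_triangular convention:
    # 0/"N": A x = b, 1/"T": A.T x = b, 2/"C": A.conj().T x = b
    def solve_triangular(A, b):
        upper = not lower
        if trans in [1, "T"]:
            A_p = A.T
            upper = lower  # transposing flips which triangle holds the data
        elif trans in [2, "C"]:
            A_p = A.conj().T  # reviewers note .H may be unavailable
            upper = lower
        else:
            A_p = A
        return torch.linalg.solve_triangular(A_p, b, upper=upper)

    return solve_triangular
```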

Resolved review threads: tests/link/pytorch/test_slinalg.py, pytensor/link/pytorch/dispatch/slinalg.py
@ricardoV94 ricardoV94 changed the title Implemented linear algebra functions in PyTorch Implement linear algebra functions in PyTorch Sep 1, 2024
@ricardoV94 (Member) commented Oct 3, 2024

@twaclaw can you rebase the PR to resolve the conflicts? And if you've addressed @jessegrabowski's comments, feel free to mark them as resolved so we can see whether anything is left to address.

Let us know if you are not available. And thanks in advance!

@twaclaw (Contributor Author) commented Oct 3, 2024


Will do shortly; I was indeed not available for the last month or so, but I am back now.

twaclaw and others added 6 commits October 3, 2024 17:18
- BlockDiagonal
- Cholesky
- Eigvalsh
- Solve
- SolveTriangular
Added support for a b param of ndim = 1.
- Removed lower param to solve
- Refactored tests
@twaclaw twaclaw force-pushed the implement_slinalg_ops_pytorch branch from 76e1471 to 4067a87 Compare October 3, 2024 15:25
Labels
enhancement New feature or request linalg Linear algebra torch PyTorch backend
3 participants