Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add LAPACK overloads for all variants of pt.linalg.solve #1146

Open
wants to merge 6 commits into
base: main
Choose a base branch
from

Conversation

jessegrabowski
Copy link
Member

@jessegrabowski jessegrabowski commented Dec 31, 2024

Description

Goal of this PR is to give numba mode full coverage of scipy.linalg.solve options. We currently only support assume_a = "gen". If users select a different solver, they get incorrect results (see #422 ). This PR should fix that bug, plus add:

  • support for overwrite_a in numba mode
  • support for overwrite_b in numba mode
  • support for transposed argument (all modes)
  • lu_factor and lu_solve Ops (all modes)
  • support for assume_a = "sym" and assume_a = "pos" in numba mode
  • support for cho_solve in numba mode

We get the lu_factor and lu_solve Ops kind of "for free" because I'm adding overloads for dgetrs and dgetrf. We just have to write the Ops and do the JAX dispatch. JVP for lu_factor is here. Help wanted. I might decide that these Ops are out of scope for this PR and open another one.

Related Issue

Checklist

Type of change

  • New feature / enhancement
  • Bug fix
  • Documentation
  • Maintenance
  • Other (please specify):

📚 Documentation preview 📚: https://pytensor--1146.org.readthedocs.build/en/1146/

@jessegrabowski jessegrabowski added bug Something isn't working enhancement New feature or request numba SciPy compatibility linalg Linear algebra labels Dec 31, 2024
@jessegrabowski jessegrabowski marked this pull request as ready for review December 31, 2024 15:24
@jessegrabowski jessegrabowski requested review from aseyboldt and ricardoV94 and removed request for aseyboldt December 31, 2024 15:24
Copy link

codecov bot commented Dec 31, 2024

Codecov Report

Attention: Patch coverage is 52.49042% with 248 lines in your changes missing coverage. Please review.

Project coverage is 81.87%. Comparing base (4e85676) to head (fe97e5d).
Report is 1 commits behind head on main.

Files with missing lines Patch % Lines
pytensor/link/numba/dispatch/slinalg.py 42.62% 205 Missing and 5 partials ⚠️
pytensor/link/numba/dispatch/_LAPACK.py 75.32% 32 Missing and 6 partials ⚠️
Additional details and impacted files

Impacted file tree graph

@@            Coverage Diff             @@
##             main    #1146      +/-   ##
==========================================
- Coverage   82.14%   81.87%   -0.28%     
==========================================
  Files         186      187       +1     
  Lines       48210    48617     +407     
  Branches     8678     8705      +27     
==========================================
+ Hits        39603    39804     +201     
- Misses       6440     6640     +200     
- Partials     2167     2173       +6     
Files with missing lines Coverage Δ
pytensor/link/numba/dispatch/basic.py 80.37% <ø> (-1.86%) ⬇️
pytensor/tensor/slinalg.py 93.49% <100.00%> (+0.01%) ⬆️
pytensor/link/numba/dispatch/_LAPACK.py 75.32% <75.32%> (ø)
pytensor/link/numba/dispatch/slinalg.py 44.42% <42.62%> (-8.38%) ⬇️

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
bug Something isn't working enhancement New feature or request linalg Linear algebra numba SciPy compatibility
Projects
None yet
Development

Successfully merging this pull request may close these issues.

BUG: pt.linalg.solve returns incorrect results when mode = "NUMBA"
1 participant