Add LAPACK overloads for all variants of pt.linalg.solve
#1146
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Description
Goal of this PR is to give numba mode full coverage of
scipy.linalg.solve
options. We currently only supportassume_a = "gen"
. If users select a different solver, they get incorrect results (see #422 ). This PR should fix that bug, plus add:overwrite_a
in numba modeoverwrite_b
in numba modetransposed
argument (all modes)lu_factor
andlu_solve
Ops
(all modes)assume_a = "sym"
andassume_a = "pos"
in numba modecho_solve
in numba modeWe get the
lu_factor
andlu_solve
Ops kind of "for free" because I'm adding overloads fordgetrs
anddgetrf
. We just have to write the Ops and do the JAX dispatch. JVP forlu_factor
is here. Help wanted. I might decide that these Ops are out of scope for this PR and open another one.Related Issue
pt.linalg.solve
returns incorrect results whenmode = "NUMBA"
#422Checklist
Type of change
📚 Documentation preview 📚: https://pytensor--1146.org.readthedocs.build/en/1146/