1+ from typing import Literal
2+
13import numpy as np
24from numba .core .extending import overload
35from numba .np .linalg import _copy_to_fortran_order , ensure_lapack
1315
1416def _xgeqrf (A : np .ndarray , overwrite_a : bool , lwork : int ):
1517 """LAPACK geqrf: Computes a QR factorization of a general M-by-N matrix A."""
16- (geqrf ,) = get_lapack_funcs (("geqrf" ,), (A ,))
18+ # (geqrf,) = typing_cast(
19+ # list[Callable[..., np.ndarray]], get_lapack_funcs(("geqrf",), (A,))
20+ # )
21+ funcs = get_lapack_funcs (("geqrf" ,), (A ,))
22+ assert isinstance (funcs , list ) # narrows `funcs: list[F] | F` to `funcs: list[F]`
23+ geqrf = funcs [0 ]
24+
1725 return geqrf (A , overwrite_a = overwrite_a , lwork = lwork )
1826
1927
@@ -61,7 +69,10 @@ def impl(A, overwrite_a, lwork):
6169
6270def _xgeqp3 (A : np .ndarray , overwrite_a : bool , lwork : int ):
6371 """LAPACK geqp3: Computes a QR factorization with column pivoting of a general M-by-N matrix A."""
64- (geqp3 ,) = get_lapack_funcs (("geqp3" ,), (A ,))
72+ funcs = get_lapack_funcs (("geqp3" ,), (A ,))
73+ assert isinstance (funcs , list ) # narrows `funcs: list[F] | F` to `funcs: list[F]`
74+ geqp3 = funcs [0 ]
75+
6576 return geqp3 (A , overwrite_a = overwrite_a , lwork = lwork )
6677
6778
@@ -111,7 +122,10 @@ def impl(A, overwrite_a, lwork):
111122
112123def _xorgqr (A : np .ndarray , tau : np .ndarray , overwrite_a : bool , lwork : int ):
113124 """LAPACK orgqr: Generates the M-by-N matrix Q with orthonormal columns from a QR factorization (real types)."""
114- (orgqr ,) = get_lapack_funcs (("orgqr" ,), (A ,))
125+ funcs = get_lapack_funcs (("orgqr" ,), (A ,))
126+ assert isinstance (funcs , list ) # narrows `funcs: list[F] | F` to `funcs: list[F]`
127+ orgqr = funcs [0 ]
128+
115129 return orgqr (A , tau , overwrite_a = overwrite_a , lwork = lwork )
116130
117131
@@ -160,7 +174,10 @@ def impl(A, tau, overwrite_a, lwork):
160174
161175def _xungqr (A : np .ndarray , tau : np .ndarray , overwrite_a : bool , lwork : int ):
162176 """LAPACK ungqr: Generates the M-by-N matrix Q with orthonormal columns from a QR factorization (complex types)."""
163- (ungqr ,) = get_lapack_funcs (("ungqr" ,), (A ,))
177+ funcs = get_lapack_funcs (("ungqr" ,), (A ,))
178+ assert isinstance (funcs , list ) # narrows `funcs: list[F] | F` to `funcs: list[F]`
179+ ungqr = funcs [0 ]
180+
164181 return ungqr (A , tau , overwrite_a = overwrite_a , lwork = lwork )
165182
166183
@@ -209,8 +226,8 @@ def impl(A, tau, overwrite_a, lwork):
209226
210227def _qr_full_pivot (
211228 x : np .ndarray ,
212- mode : str = "full" ,
213- pivoting : bool = True ,
229+ mode : Literal [ "full" , "economic" ] = "full" ,
230+ pivoting : Literal [ True ] = True ,
214231 overwrite_a : bool = False ,
215232 check_finite : bool = False ,
216233 lwork : int | None = None ,
@@ -234,8 +251,8 @@ def _qr_full_pivot(
234251
235252def _qr_full_no_pivot (
236253 x : np .ndarray ,
237- mode : str = "full" ,
238- pivoting : bool = False ,
254+ mode : Literal [ "full" , "economic" ] = "full" ,
255+ pivoting : Literal [ False ] = False ,
239256 overwrite_a : bool = False ,
240257 check_finite : bool = False ,
241258 lwork : int | None = None ,
@@ -258,8 +275,8 @@ def _qr_full_no_pivot(
258275
259276def _qr_r_pivot (
260277 x : np .ndarray ,
261- mode : str = "r" ,
262- pivoting : bool = True ,
278+ mode : Literal [ "r" , "raw" ] = "r" ,
279+ pivoting : Literal [ True ] = True ,
263280 overwrite_a : bool = False ,
264281 check_finite : bool = False ,
265282 lwork : int | None = None ,
@@ -282,8 +299,8 @@ def _qr_r_pivot(
282299
283300def _qr_r_no_pivot (
284301 x : np .ndarray ,
285- mode : str = "r" ,
286- pivoting : bool = False ,
302+ mode : Literal [ "r" , "raw" ] = "r" ,
303+ pivoting : Literal [ False ] = False ,
287304 overwrite_a : bool = False ,
288305 check_finite : bool = False ,
289306 lwork : int | None = None ,
@@ -306,8 +323,8 @@ def _qr_r_no_pivot(
306323
307324def _qr_raw_no_pivot (
308325 x : np .ndarray ,
309- mode : str = "raw" ,
310- pivoting : bool = False ,
326+ mode : Literal [ "raw" ] = "raw" ,
327+ pivoting : Literal [ False ] = False ,
311328 overwrite_a : bool = False ,
312329 check_finite : bool = False ,
313330 lwork : int | None = None ,
@@ -332,8 +349,8 @@ def _qr_raw_no_pivot(
332349
333350def _qr_raw_pivot (
334351 x : np .ndarray ,
335- mode : str = "raw" ,
336- pivoting : bool = True ,
352+ mode : Literal [ "raw" ] = "raw" ,
353+ pivoting : Literal [ True ] = True ,
337354 overwrite_a : bool = False ,
338355 check_finite : bool = False ,
339356 lwork : int | None = None ,
0 commit comments