Skip to content

Commit cace929

Browse files
Implement numba solve for assume_a = "gen"
1 parent 4b7b008 commit cace929

File tree

5 files changed

+734
-237
lines changed

5 files changed

+734
-237
lines changed
+264
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,264 @@
1+
import ctypes
2+
3+
import numpy as np
4+
from numba.core import cgutils, types
5+
from numba.core.extending import get_cython_function_address, intrinsic
6+
from numba.np.linalg import ensure_lapack, get_blas_kind
7+
8+
9+
_PTR = ctypes.POINTER
10+
11+
_dbl = ctypes.c_double
12+
_float = ctypes.c_float
13+
_char = ctypes.c_char
14+
_int = ctypes.c_int
15+
16+
_ptr_float = _PTR(_float)
17+
_ptr_dbl = _PTR(_dbl)
18+
_ptr_char = _PTR(_char)
19+
_ptr_int = _PTR(_int)
20+
21+
22+
def _get_lapack_ptr_and_ptr_type(dtype, name):
23+
d = get_blas_kind(dtype)
24+
func_name = f"{d}{name}"
25+
float_pointer = _get_float_pointer_for_dtype(d)
26+
lapack_ptr = get_cython_function_address("scipy.linalg.cython_lapack", func_name)
27+
28+
return lapack_ptr, float_pointer
29+
30+
31+
def _get_underlying_float(dtype):
32+
s_dtype = str(dtype)
33+
out_type = s_dtype
34+
if s_dtype == "complex64":
35+
out_type = "float32"
36+
elif s_dtype == "complex128":
37+
out_type = "float64"
38+
39+
return np.dtype(out_type)
40+
41+
42+
def _get_float_pointer_for_dtype(blas_dtype):
43+
if blas_dtype in ["s", "c"]:
44+
return _ptr_float
45+
elif blas_dtype in ["d", "z"]:
46+
return _ptr_dbl
47+
48+
49+
def _get_output_ctype(dtype):
50+
s_dtype = str(dtype)
51+
if s_dtype in ["float32", "complex64"]:
52+
return _float
53+
elif s_dtype in ["float64", "complex128"]:
54+
return _dbl
55+
56+
57+
@intrinsic
58+
def sptr_to_val(typingctx, data):
59+
def impl(context, builder, signature, args):
60+
val = builder.load(args[0])
61+
return val
62+
63+
sig = types.float32(types.CPointer(types.float32))
64+
return sig, impl
65+
66+
67+
@intrinsic
68+
def dptr_to_val(typingctx, data):
69+
def impl(context, builder, signature, args):
70+
val = builder.load(args[0])
71+
return val
72+
73+
sig = types.float64(types.CPointer(types.float64))
74+
return sig, impl
75+
76+
77+
@intrinsic
78+
def int_ptr_to_val(typingctx, data):
79+
def impl(context, builder, signature, args):
80+
val = builder.load(args[0])
81+
return val
82+
83+
sig = types.int32(types.CPointer(types.int32))
84+
return sig, impl
85+
86+
87+
@intrinsic
88+
def val_to_int_ptr(typingctx, data):
89+
def impl(context, builder, signature, args):
90+
ptr = cgutils.alloca_once_value(builder, args[0])
91+
return ptr
92+
93+
sig = types.CPointer(types.int32)(types.int32)
94+
return sig, impl
95+
96+
97+
@intrinsic
98+
def val_to_sptr(typingctx, data):
99+
def impl(context, builder, signature, args):
100+
ptr = cgutils.alloca_once_value(builder, args[0])
101+
return ptr
102+
103+
sig = types.CPointer(types.float32)(types.float32)
104+
return sig, impl
105+
106+
107+
@intrinsic
108+
def val_to_zptr(typingctx, data):
109+
def impl(context, builder, signature, args):
110+
ptr = cgutils.alloca_once_value(builder, args[0])
111+
return ptr
112+
113+
sig = types.CPointer(types.complex128)(types.complex128)
114+
return sig, impl
115+
116+
117+
@intrinsic
118+
def val_to_dptr(typingctx, data):
119+
def impl(context, builder, signature, args):
120+
ptr = cgutils.alloca_once_value(builder, args[0])
121+
return ptr
122+
123+
sig = types.CPointer(types.float64)(types.float64)
124+
return sig, impl
125+
126+
127+
class _LAPACK:
128+
"""
129+
Functions to return type signatures for wrapped LAPACK functions.
130+
131+
Patterned after https://github.com/numba/numba/blob/bd7ebcfd4b850208b627a3f75d4706000be36275/numba/np/linalg.py#L74
132+
"""
133+
134+
def __init__(self):
135+
ensure_lapack()
136+
137+
@classmethod
138+
def numba_xtrtrs(cls, dtype):
139+
"""
140+
Solve a triangular system of equations of the form A @ X = B or A.T @ X = B.
141+
142+
Called by scipy.linalg.solve_triangular
143+
"""
144+
lapack_ptr, float_pointer = _get_lapack_ptr_and_ptr_type(dtype, "trtrs")
145+
146+
functype = ctypes.CFUNCTYPE(
147+
None,
148+
_ptr_int, # UPLO
149+
_ptr_int, # TRANS
150+
_ptr_int, # DIAG
151+
_ptr_int, # N
152+
_ptr_int, # NRHS
153+
float_pointer, # A
154+
_ptr_int, # LDA
155+
float_pointer, # B
156+
_ptr_int, # LDB
157+
_ptr_int, # INFO
158+
)
159+
160+
return functype(lapack_ptr)
161+
162+
@classmethod
163+
def numba_xpotrf(cls, dtype):
164+
"""
165+
Compute the Cholesky factorization of a real symmetric positive definite matrix.
166+
167+
Called by scipy.linalg.cholesky
168+
"""
169+
lapack_ptr, float_pointer = _get_lapack_ptr_and_ptr_type(dtype, "potrf")
170+
functype = ctypes.CFUNCTYPE(
171+
None,
172+
_ptr_int, # UPLO,
173+
_ptr_int, # N
174+
float_pointer, # A
175+
_ptr_int, # LDA
176+
_ptr_int, # INFO
177+
)
178+
return functype(lapack_ptr)
179+
180+
@classmethod
181+
def numba_xlange(cls, dtype):
182+
"""
183+
Compute the value of the 1-norm, Frobenius norm, infinity-norm, or the largest absolute value of any element of
184+
a general M-by-N matrix A.
185+
186+
Called by scipy.linalg.solve
187+
"""
188+
lapack_ptr, float_pointer = _get_lapack_ptr_and_ptr_type(dtype, "lange")
189+
output_ctype = _get_output_ctype(dtype)
190+
functype = ctypes.CFUNCTYPE(
191+
output_ctype, # Output
192+
_ptr_int, # NORM
193+
_ptr_int, # M
194+
_ptr_int, # N
195+
float_pointer, # A
196+
_ptr_int, # LDA
197+
float_pointer, # WORK
198+
)
199+
return functype(lapack_ptr)
200+
201+
@classmethod
202+
def numba_xgecon(cls, dtype):
203+
"""
204+
Estimates the condition number of a matrix A, using the LU factorization computed by numba_getrf.
205+
206+
Called by scipy.linalg.solve when assume_a == "gen"
207+
"""
208+
lapack_ptr, float_pointer = _get_lapack_ptr_and_ptr_type(dtype, "gecon")
209+
functype = ctypes.CFUNCTYPE(
210+
None,
211+
_ptr_int, # NORM
212+
_ptr_int, # N
213+
float_pointer, # A
214+
_ptr_int, # LDA
215+
float_pointer, # ANORM
216+
float_pointer, # RCOND
217+
float_pointer, # WORK
218+
_ptr_int, # IWORK
219+
_ptr_int, # INFO
220+
)
221+
return functype(lapack_ptr)
222+
223+
@classmethod
224+
def numba_xgetrf(cls, dtype):
225+
"""
226+
Compute partial pivoting LU factorization of a general M-by-N matrix A using row interchanges.
227+
228+
Called by scipy.linalg.solve when assume_a == "gen"
229+
"""
230+
lapack_ptr, float_pointer = _get_lapack_ptr_and_ptr_type(dtype, "getrf")
231+
functype = ctypes.CFUNCTYPE(
232+
None,
233+
_ptr_int, # M
234+
_ptr_int, # N
235+
float_pointer, # A
236+
_ptr_int, # LDA
237+
_ptr_int, # IPIV
238+
_ptr_int, # INFO
239+
)
240+
return functype(lapack_ptr)
241+
242+
@classmethod
243+
def numba_xgetrs(cls, dtype):
244+
"""
245+
Solve a system of linear equations A @ X = B or A.T @ X = B with a general N-by-N matrix A using the LU
246+
factorization computed by numba_getrf.
247+
248+
Called by scipy.linalg.solve when assume_a == "gen"
249+
"""
250+
...
251+
lapack_ptr, float_pointer = _get_lapack_ptr_and_ptr_type(dtype, "getrs")
252+
functype = ctypes.CFUNCTYPE(
253+
None,
254+
_ptr_int, # TRANS
255+
_ptr_int, # N
256+
_ptr_int, # NRHS
257+
float_pointer, # A
258+
_ptr_int, # LDA
259+
_ptr_int, # IPIV
260+
float_pointer, # B
261+
_ptr_int, # LDB
262+
_ptr_int, # INFO
263+
)
264+
return functype(lapack_ptr)

0 commit comments

Comments
 (0)