Skip to content

Commit 7eb82c2

Browse files
committed
Add dgetri_
1 parent a5bf6ee commit 7eb82c2

File tree

2 files changed

+59
-2
lines changed

2 files changed

+59
-2
lines changed

Sources/AccelerateLinux/MatrixOps/LAPACK.swift

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ public func dgesv_(
2323
_ __info: UnsafeMutablePointer<__CLPK_integer>!
2424
)
2525

26-
/// DGESVD computes the singular value decomposition (SVD) of a real
26+
/// Computes the singular value decomposition (SVD) of a real
2727
/// M-by-N matrix A, optionally computing the left and/or right singular
2828
/// vectors. The SVD is written
2929
///
@@ -55,7 +55,7 @@ public func dgesvd_(
5555
_ __info: UnsafeMutablePointer<__CLPK_integer>!
5656
)
5757
58-
/// DGETRF computes an LU factorization of a general M-by-N matrix A
58+
/// Computes an LU factorization of a general M-by-N matrix A
5959
/// using partial pivoting with row interchanges.
6060
///
6161
/// The factorization has the form
@@ -74,4 +74,20 @@ public func dgetrf_(
7474
_ __ipiv: UnsafeMutablePointer<__CLPK_integer>!,
7575
_ __info: UnsafeMutablePointer<__CLPK_integer>!
7676
) -> Int32
77+
78+
/// Computes the inverse of a matrix using the LU factorization
79+
/// computed by ``dgterf_``.
80+
///
81+
/// This method inverts U and then computes `inv(A)` by solving the system
82+
/// `inv(A)*L = inv(U) for inv(A)`.
83+
@_silgen_name("dgetri_")
84+
public func dgetri_(
85+
_ __n: UnsafeMutablePointer<__CLPK_integer>!,
86+
_ __a: UnsafeMutablePointer<__CLPK_doublereal>!,
87+
_ __lda: UnsafeMutablePointer<__CLPK_integer>!,
88+
_ __ipiv: UnsafeMutablePointer<__CLPK_integer>!,
89+
_ __work: UnsafeMutablePointer<__CLPK_doublereal>!,
90+
_ __lwork: UnsafeMutablePointer<__CLPK_integer>!,
91+
_ __info: UnsafeMutablePointer<__CLPK_integer>!
92+
) -> Int32
7793
#endif // canImport(Accelerate)

Tests/AccelerateLinuxTests/MatrixTests/LAPACKTests.swift

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -148,4 +148,45 @@ struct LAPACKTests {
148148
]
149149
)
150150
}
151+
152+
// https://numericalalgorithmsgroup.github.io/LAPACK_Examples/examples/doc/dgetri_example.html
153+
@Test("dgetri_")
154+
func test_dgetri_() {
155+
let M = 4
156+
let N = 4
157+
let LDA = M
158+
159+
var m = __CLPK_integer(M)
160+
var n = __CLPK_integer(N)
161+
var lda = __CLPK_integer(LDA)
162+
var info = __CLPK_integer(0)
163+
164+
var ipiv = [__CLPK_integer](repeating: 0, count: Int(min(m, n)))
165+
166+
var a: [__CLPK_doublereal] = [
167+
1.80, 5.25, 1.58, -1.11,
168+
2.88, -2.95, -2.69, -0.66,
169+
2.05, -0.95, -2.90, -0.59,
170+
-0.89, -3.80, -1.04, 0.80,
171+
]
172+
173+
// A is nonsymmetric and must be factorized first
174+
_ = dgetrf_(&m, &n, &a, &lda, &ipiv, &info)
175+
176+
var work = [__CLPK_doublereal](repeating: 0.0, count: Int(m))
177+
var lwork = __CLPK_integer(64 * n)
178+
179+
_ = dgetri_(&n, &a, &lda, &ipiv, &work, &lwork, &info)
180+
181+
#expect(
182+
a.map {
183+
($0 * pow(10, 2)).rounded() / pow(10, 2)
184+
} == [
185+
1.77, -0.12, 0.18, 2.49,
186+
0.58, -0.45, 0.45, 0.76,
187+
0.08, 0.41, -0.67, -0.04,
188+
4.82, -1.71, 1.48, 7.61,
189+
]
190+
)
191+
}
151192
}

0 commit comments

Comments
 (0)