|
| 1 | +#if !ACCELERATE_NEW_LAPACK |
| 2 | +import AccelerateLinux |
| 3 | +import Foundation |
| 4 | +import Testing |
| 5 | + |
| 6 | +@Suite("CPLK Tests") |
| 7 | +struct CPLKTests { |
| 8 | + // https://www.intel.com/content/www/us/en/docs/onemkl/code-samples-lapack/2022-1/dgesv-example-c.html |
| 9 | + @Test("dgesv_") |
| 10 | + func test_dgesv_() { |
| 11 | + let N = 5 |
| 12 | + let NRHS = 3 |
| 13 | + let LDA = N |
| 14 | + let LDB = N |
| 15 | + |
| 16 | + var n = __CLPK_integer(N) |
| 17 | + var nrhs = __CLPK_integer(NRHS) |
| 18 | + var lda = __CLPK_integer(LDA) |
| 19 | + var ldb = __CLPK_integer(LDB) |
| 20 | + var info = __CLPK_integer(0) |
| 21 | + |
| 22 | + var ipiv = [__CLPK_integer](repeating: 0, count: Int(N)) |
| 23 | + var a: [__CLPK_doublereal] = [ |
| 24 | + 6.80, -2.11, 5.66, 5.97, 8.23, -6.05, -3.30, |
| 25 | + 5.36, -4.44, 1.08, -0.45, 2.58, -2.70, 0.27, |
| 26 | + 9.04, 8.32, 2.71, 4.35, -7.17, 2.14, -9.67, |
| 27 | + -5.14, -7.26, 6.08, -6.87, |
| 28 | + ] |
| 29 | + |
| 30 | + var b: [__CLPK_doublereal] = [ |
| 31 | + 4.02, 6.19, -8.22, -7.57, -3.03, -1.56, 4.00, -8.67, |
| 32 | + 1.75, 2.86, 9.81, -4.09, -4.57, -8.61, 8.99, |
| 33 | + ] |
| 34 | + |
| 35 | + dgesv_(&n, &nrhs, &a, &lda, &ipiv, &b, &ldb, &info) |
| 36 | + |
| 37 | + if info > 0 { |
| 38 | + Issue.record( |
| 39 | + "The diagonal element of the triangular factor of A, U, is zero. The factorization has been completed, but the factor U is exactly singular, so the solution could not be computed." |
| 40 | + ) |
| 41 | + return |
| 42 | + } |
| 43 | + |
| 44 | + #expect( |
| 45 | + b.map { |
| 46 | + ($0 * pow(10, 2)).rounded() / pow(10, 2) |
| 47 | + } == [-0.80, -0.70, 0.59, 1.32, 0.57, -0.39, -0.55, 0.84, -0.10, 0.11, 0.96, 0.22, 1.90, 5.36, 4.04] |
| 48 | + ) |
| 49 | + } |
| 50 | + |
| 51 | + // https://www.intel.com/content/www/us/en/docs/onemkl/code-samples-lapack/2022-1/dgesvd-example-c.html |
| 52 | + @Test("dgesvd_") |
| 53 | + func test_dgesvd_() { |
| 54 | + let M = 6 |
| 55 | + let N = 5 |
| 56 | + let LDA = M |
| 57 | + let LDU = M |
| 58 | + let LDVT = N |
| 59 | + |
| 60 | + var m = __CLPK_integer(M) |
| 61 | + var n = __CLPK_integer(N) |
| 62 | + var lda = __CLPK_integer(LDA) |
| 63 | + var ldu = __CLPK_integer(LDU) |
| 64 | + var ldvt = __CLPK_integer(LDVT) |
| 65 | + var info = __CLPK_integer(0) |
| 66 | + var lwork = __CLPK_integer(-1) |
| 67 | + var wkopt = __CLPK_doublereal(0) |
| 68 | + |
| 69 | + var a: [__CLPK_doublereal] = [ |
| 70 | + 8.79, 6.11, -9.15, 9.57, -3.49, 9.84, 9.93, 6.91, |
| 71 | + -7.93, 1.64, 4.02, 0.15, 9.83, 5.04, 4.86, 8.83, |
| 72 | + 9.80, -8.99, 5.45, -0.27, 4.85, 0.74, 10.00, -6.02, |
| 73 | + 3.16, 7.98, 3.01, 5.80, 4.27, -5.31, |
| 74 | + ] |
| 75 | + |
| 76 | + var s = [__CLPK_doublereal](repeating: 0, count: Int(min(m, n))) |
| 77 | + var u = [__CLPK_doublereal](repeating: 0, count: Int(ldu * m)) |
| 78 | + var vt = [__CLPK_doublereal](repeating: 0, count: Int(ldvt * n)) |
| 79 | + |
| 80 | + let jobu = "A" |
| 81 | + let jobvt = "A" |
| 82 | + |
| 83 | + jobu.withCString { jobuPtr in |
| 84 | + jobvt.withCString { jobvtPtr in |
| 85 | + let mutableJobuPtr = UnsafeMutablePointer(mutating: jobuPtr) |
| 86 | + let mutableJobvtPtr = UnsafeMutablePointer(mutating: jobvtPtr) |
| 87 | + |
| 88 | + dgesvd_(mutableJobuPtr, mutableJobvtPtr, &m, &n, &a, &lda, &s, &u, &ldu, &vt, &ldvt, &wkopt, &lwork, &info) |
| 89 | + } |
| 90 | + } |
| 91 | + |
| 92 | + lwork = __CLPK_integer(wkopt) |
| 93 | + var work = [__CLPK_doublereal](repeating: 0, count: Int(lwork)) |
| 94 | + |
| 95 | + jobu.withCString { jobuPtr in |
| 96 | + jobvt.withCString { jobvtPtr in |
| 97 | + let mutableJobuPtr = UnsafeMutablePointer(mutating: jobuPtr) |
| 98 | + let mutableJobvtPtr = UnsafeMutablePointer(mutating: jobvtPtr) |
| 99 | + |
| 100 | + dgesvd_(mutableJobuPtr, mutableJobvtPtr, &m, &n, &a, &lda, &s, &u, &ldu, &vt, &ldvt, &work, &lwork, &info) |
| 101 | + } |
| 102 | + } |
| 103 | + |
| 104 | + if info > 0 { |
| 105 | + Issue.record("The algorithm computing SVD failed to converge.") |
| 106 | + return |
| 107 | + } |
| 108 | + |
| 109 | + #expect( |
| 110 | + s.map { |
| 111 | + ($0 * pow(10, 2)).rounded() / pow(10, 2) |
| 112 | + } == [27.47, 22.64, 8.56, 5.99, 2.01] |
| 113 | + ) |
| 114 | + |
| 115 | + // Note: if we care about orthogonality of U and V we should check those too |
| 116 | + } |
| 117 | +} |
| 118 | +#endif |
0 commit comments