Skip to content

Commit ca73b5f

Browse files
committed
Fix sparse gemm and gemv
1 parent ff1602e commit ca73b5f

File tree

4 files changed

+313
-9
lines changed

4 files changed

+313
-9
lines changed

lib/mkl/interfaces.jl

Lines changed: 245 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,13 @@ function LinearAlgebra.generic_matvecmul!(C::oneVector{T}, tA::AbstractChar, A::
77
sparse_gemv!(tA, _add.alpha, A, B, _add.beta, C)
88
end
99

10-
function LinearAlgebra.generic_matvecmul!(C::oneVector{T}, tA::AbstractChar, A::oneSparseMatrixCSC{T}, B::oneVector{T}, _add::MulAddMul) where T <: BlasReal
11-
tA = tA in ('S', 's', 'H', 'h') ? 'T' : flip_trans(tA)
10+
# Matrix-vector product for a CSC sparse matrix: forwards to `sparse_gemv!`, which
# updates `C` with `alpha * op(A) * B + beta * C` (`alpha`/`beta` come from `_add`).
function LinearAlgebra.generic_matvecmul!(
        C::oneVector{T}, tA::AbstractChar, A::oneSparseMatrixCSC{T}, B::oneVector{T},
        _add::MulAddMul
    ) where {T <: BlasFloat}
    # The characters 'S', 's', 'H' and 'h' are forwarded as a plain ('N') product.
    trans = (tA == 'S' || tA == 's' || tA == 'H' || tA == 'h') ? 'N' : tA
    return sparse_gemv!(trans, _add.alpha, A, B, _add.beta, C)
end
14+
15+
# Matrix-vector product for a COO sparse matrix: forwards to `sparse_gemv!`, which
# updates `C` with `alpha * op(A) * B + beta * C` (`alpha`/`beta` come from `_add`).
function LinearAlgebra.generic_matvecmul!(
        C::oneVector{T}, tA::AbstractChar, A::oneSparseMatrixCOO{T}, B::oneVector{T},
        _add::MulAddMul
    ) where {T <: BlasFloat}
    # The characters 'S', 's', 'H' and 'h' are forwarded as a plain ('N') product.
    trans = (tA == 'S' || tA == 's' || tA == 'H' || tA == 'h') ? 'N' : tA
    return sparse_gemv!(trans, _add.alpha, A, B, _add.beta, C)
end
1419

@@ -18,8 +23,14 @@ function LinearAlgebra.generic_matmatmul!(C::oneMatrix{T}, tA, tB, A::oneSparseM
1823
sparse_gemm!(tA, tB, _add.alpha, A, B, _add.beta, C)
1924
end
2025

21-
function LinearAlgebra.generic_matmatmul!(C::oneMatrix{T}, tA, tB, A::oneSparseMatrixCSC{T}, B::oneMatrix{T}, _add::MulAddMul) where T <: BlasReal
22-
tA = tA in ('S', 's', 'H', 'h') ? 'T' : flip_trans(tA)
26+
# Sparse-times-dense product for a CSC sparse matrix: forwards to `sparse_gemm!`, which
# updates `C` with `alpha * op(A) * op(B) + beta * C` (`alpha`/`beta` come from `_add`).
function LinearAlgebra.generic_matmatmul!(
        C::oneMatrix{T}, tA, tB, A::oneSparseMatrixCSC{T}, B::oneMatrix{T},
        _add::MulAddMul
    ) where {T <: BlasFloat}
    # The characters 'S', 's', 'H' and 'h' are forwarded as a plain ('N') operand.
    plain = ('S', 's', 'H', 'h')
    transa = tA in plain ? 'N' : tA
    transb = tB in plain ? 'N' : tB
    return sparse_gemm!(transa, transb, _add.alpha, A, B, _add.beta, C)
end
31+
32+
# Sparse-times-dense product for a COO sparse matrix: forwards to `sparse_gemm!`, which
# updates `C` with `alpha * op(A) * op(B) + beta * C` (`alpha`/`beta` come from `_add`).
function LinearAlgebra.generic_matmatmul!(
        C::oneMatrix{T}, tA, tB, A::oneSparseMatrixCOO{T}, B::oneMatrix{T},
        _add::MulAddMul
    ) where {T <: BlasFloat}
    # The characters 'S', 's', 'H' and 'h' are forwarded as a plain ('N') operand.
    plain = ('S', 's', 'H', 'h')
    transa = tA in plain ? 'N' : tA
    transb = tB in plain ? 'N' : tB
    return sparse_gemm!(transa, transb, _add.alpha, A, B, _add.beta, C)
end
@@ -31,3 +42,233 @@ end
3142
# Triangular solve with a CSR sparse matrix and a dense right-hand side: translates the
# wrapper function `tfun` into a transposition character and forwards to `sparse_trsm!`
# with a scaling factor of `one(T)`.
function LinearAlgebra.generic_trimatdiv!(
        C::oneMatrix{T}, uploc, isunitc, tfun::Function, A::oneSparseMatrixCSR{T},
        B::oneMatrix{T}
    ) where {T <: BlasFloat}
    # `identity` -> 'N', `transpose` -> 'T', anything else -> 'C'.
    transa = if tfun === identity
        'N'
    elseif tfun === transpose
        'T'
    else
        'C'
    end
    return sparse_trsm!(uploc, transa, 'N', isunitc, one(T), A, B, C)
end
45+
46+
# Handle Transpose and Adjoint wrappers for sparse matrices
47+
# Let the low-level wrappers handle the CSC->CSR conversion and flip_trans logic
48+
49+
# Matrix-vector multiplication with transpose/adjoint
50+
# Matrix-vector product where `A` is a lazy `transpose` of a CSR matrix `P = A.parent`.
#
# Updates `C` with `alpha * op(A) * B + beta * C` (`alpha`/`beta` come from `_add`), where
# `op` is selected by `tA`. Combined with the wrapper this gives, in terms of `P`:
#   'N' -> transpose(P)           (kernel op 'T')
#   'T' -> P                      (kernel op 'N')
#   'C' -> adjoint(transpose(P)) == conj(P)
# Returns `C`.
function LinearAlgebra.generic_matvecmul!(
        C::oneVector{T}, tA::AbstractChar, A::Transpose{T, <:oneSparseMatrixCSR{T}},
        B::oneVector{T}, _add::MulAddMul
    ) where {T <: BlasFloat}
    tA = tA in ('S', 's', 'H', 'h') ? 'N' : tA
    P = A.parent
    if tA == 'N'
        sparse_gemv!('T', _add.alpha, P, B, _add.beta, C)
    elseif tA == 'T' || T <: Real
        # For real element types conj(P) == P, so 'C' collapses to a plain product too.
        sparse_gemv!('N', _add.alpha, P, B, _add.beta, C)
    else
        # conj(P) has no transposition character (passing 'C' would compute Pᴴ, which is
        # a different matrix). Use conj(C) = conj(alpha) * P * conj(B) + conj(beta) * conj(C)
        # instead, conjugating `B` and `C` in place around the call.
        @. B = conj(B)
        @. C = conj(C)
        try
            sparse_gemv!('N', conj(_add.alpha), P, B, conj(_add.beta), C)
        finally
            # Undo the temporary conjugation even if the call throws, so the caller's
            # input `B` is never left modified.
            @. C = conj(C)
            @. B = conj(B)
        end
    end
    return C
end
55+
56+
# Matrix-vector product where `A` is a lazy `adjoint` of a CSR matrix `P = A.parent`.
#
# Updates `C` with `alpha * op(A) * B + beta * C` (`alpha`/`beta` come from `_add`) and
# returns `C`. In terms of `P`: 'N' -> kernel op 'C', 'T' -> conj(P) (handled by
# conjugating the operands), any other character -> kernel op 'N'.
function LinearAlgebra.generic_matvecmul!(
        C::oneVector{T}, tA::AbstractChar, A::Adjoint{T, <:oneSparseMatrixCSR{T}},
        B::oneVector{T}, _add::MulAddMul
    ) where {T <: BlasFloat}
    op = tA in ('S', 's', 'H', 'h') ? 'N' : tA
    P = A.parent
    if op != 'T'
        sparse_gemv!(op == 'N' ? 'C' : 'N', _add.alpha, P, B, _add.beta, C)
        return C
    end
    # transpose(adjoint(P)) == conj(P): evaluate the conjugated equation
    # conj(C) = conj(alpha) * P * conj(B) + conj(beta) * conj(C), then conjugate back.
    a = conj(_add.alpha)
    b = conj(_add.beta)
    @. B = conj(B)
    @. C = conj(C)
    sparse_gemv!('N', a, P, B, b, C)
    @. C = conj(C)
    @. B = conj(B)
    return C
end
72+
73+
# Matrix-vector product where `A` is a lazy `transpose` of a CSC matrix `P = A.parent`.
#
# Updates `C` with `alpha * op(A) * B + beta * C` (`alpha`/`beta` come from `_add`), where
# `op` is selected by `tA`. Combined with the wrapper this gives, in terms of `P`:
#   'N' -> transpose(P)           (kernel op 'T')
#   'T' -> P                      (kernel op 'N')
#   'C' -> adjoint(transpose(P)) == conj(P)
# Returns `C`.
function LinearAlgebra.generic_matvecmul!(
        C::oneVector{T}, tA::AbstractChar, A::Transpose{T, <:oneSparseMatrixCSC{T}},
        B::oneVector{T}, _add::MulAddMul
    ) where {T <: BlasFloat}
    tA = tA in ('S', 's', 'H', 'h') ? 'N' : tA
    P = A.parent
    if tA == 'N'
        sparse_gemv!('T', _add.alpha, P, B, _add.beta, C)
    elseif tA == 'T' || T <: Real
        # For real element types conj(P) == P, so 'C' collapses to a plain product too.
        sparse_gemv!('N', _add.alpha, P, B, _add.beta, C)
    else
        # conj(P) has no transposition character (passing 'C' would compute Pᴴ, which is
        # a different matrix). Use conj(C) = conj(alpha) * P * conj(B) + conj(beta) * conj(C)
        # instead, conjugating `B` and `C` in place around the call.
        @. B = conj(B)
        @. C = conj(C)
        try
            sparse_gemv!('N', conj(_add.alpha), P, B, conj(_add.beta), C)
        finally
            # Undo the temporary conjugation even if the call throws, so the caller's
            # input `B` is never left modified.
            @. C = conj(C)
            @. B = conj(B)
        end
    end
    return C
end
78+
79+
# Matrix-vector product where `A` is a lazy `adjoint` of a CSC matrix `P = A.parent`.
#
# Updates `C` with `alpha * op(A) * B + beta * C` (`alpha`/`beta` come from `_add`) and
# returns `C`. In terms of `P`: 'N' -> kernel op 'C', 'T' -> conj(P) (handled by
# conjugating the operands), any other character -> kernel op 'N'.
function LinearAlgebra.generic_matvecmul!(
        C::oneVector{T}, tA::AbstractChar, A::Adjoint{T, <:oneSparseMatrixCSC{T}},
        B::oneVector{T}, _add::MulAddMul
    ) where {T <: BlasFloat}
    op = tA in ('S', 's', 'H', 'h') ? 'N' : tA
    P = A.parent
    if op != 'T'
        sparse_gemv!(op == 'N' ? 'C' : 'N', _add.alpha, P, B, _add.beta, C)
        return C
    end
    # transpose(adjoint(P)) == conj(P): evaluate the conjugated equation
    # conj(C) = conj(alpha) * P * conj(B) + conj(beta) * conj(C), then conjugate back.
    a = conj(_add.alpha)
    b = conj(_add.beta)
    @. B = conj(B)
    @. C = conj(C)
    sparse_gemv!('N', a, P, B, b, C)
    @. C = conj(C)
    @. B = conj(B)
    return C
end
95+
96+
# Matrix-vector product where `A` is a lazy `transpose` of a COO matrix `P = A.parent`.
#
# Updates `C` with `alpha * op(A) * B + beta * C` (`alpha`/`beta` come from `_add`), where
# `op` is selected by `tA`. Combined with the wrapper this gives, in terms of `P`:
#   'N' -> transpose(P)           (kernel op 'T')
#   'T' -> P                      (kernel op 'N')
#   'C' -> adjoint(transpose(P)) == conj(P)
# Returns `C`.
function LinearAlgebra.generic_matvecmul!(
        C::oneVector{T}, tA::AbstractChar, A::Transpose{T, <:oneSparseMatrixCOO{T}},
        B::oneVector{T}, _add::MulAddMul
    ) where {T <: BlasFloat}
    tA = tA in ('S', 's', 'H', 'h') ? 'N' : tA
    P = A.parent
    if tA == 'N'
        sparse_gemv!('T', _add.alpha, P, B, _add.beta, C)
    elseif tA == 'T' || T <: Real
        # For real element types conj(P) == P, so 'C' collapses to a plain product too.
        sparse_gemv!('N', _add.alpha, P, B, _add.beta, C)
    else
        # conj(P) has no transposition character (passing 'C' would compute Pᴴ, which is
        # a different matrix). Use conj(C) = conj(alpha) * P * conj(B) + conj(beta) * conj(C)
        # instead, conjugating `B` and `C` in place around the call.
        @. B = conj(B)
        @. C = conj(C)
        try
            sparse_gemv!('N', conj(_add.alpha), P, B, conj(_add.beta), C)
        finally
            # Undo the temporary conjugation even if the call throws, so the caller's
            # input `B` is never left modified.
            @. C = conj(C)
            @. B = conj(B)
        end
    end
    return C
end
101+
102+
# Matrix-vector product where `A` is a lazy `adjoint` of a COO matrix `P = A.parent`.
#
# Updates `C` with `alpha * op(A) * B + beta * C` (`alpha`/`beta` come from `_add`) and
# returns `C`. In terms of `P`: 'N' -> kernel op 'C', 'T' -> conj(P) (handled by
# conjugating the operands), any other character -> kernel op 'N'.
function LinearAlgebra.generic_matvecmul!(
        C::oneVector{T}, tA::AbstractChar, A::Adjoint{T, <:oneSparseMatrixCOO{T}},
        B::oneVector{T}, _add::MulAddMul
    ) where {T <: BlasFloat}
    op = tA in ('S', 's', 'H', 'h') ? 'N' : tA
    P = A.parent
    if op != 'T'
        sparse_gemv!(op == 'N' ? 'C' : 'N', _add.alpha, P, B, _add.beta, C)
        return C
    end
    # transpose(adjoint(P)) == conj(P): evaluate the conjugated equation
    # conj(C) = conj(alpha) * P * conj(B) + conj(beta) * conj(C), then conjugate back.
    a = conj(_add.alpha)
    b = conj(_add.beta)
    @. B = conj(B)
    @. C = conj(C)
    sparse_gemv!('N', a, P, B, b, C)
    @. C = conj(C)
    @. B = conj(B)
    return C
end
118+
119+
# Handle Transpose{T, Adjoint{T, ...}} for complex matrices
120+
# transpose(adjoint(A)) for complex matrices needs special handling
121+
# Matrix-vector product where `A` is `transpose(adjoint(P))` of a complex CSR matrix `P`.
#
# Since transpose(adjoint(P)) == conj(P), the product is evaluated on the conjugated
# equation: `B` and `C` are conjugated in place, the kernel runs on `P` with conjugated
# `alpha`/`beta`, and both vectors are conjugated back afterwards. Returns `C`.
function LinearAlgebra.generic_matvecmul!(
        C::oneVector{T}, tA::AbstractChar,
        A::Transpose{T, <:Adjoint{T, <:oneSparseMatrixCSR{T}}},
        B::oneVector{T}, _add::MulAddMul
    ) where {T <: BlasComplex}
    tA = tA in ('S', 's', 'H', 'h') ? 'N' : tA
    # 'N' and 'T' are passed through; every other character is treated as 'C'.
    op = tA == 'N' ? 'N' : (tA == 'T' ? 'T' : 'C')
    P = A.parent.parent
    a = conj(_add.alpha)
    b = conj(_add.beta)
    @. B = conj(B)
    @. C = conj(C)
    sparse_gemv!(op, a, P, B, b, C)
    @. C = conj(C)
    @. B = conj(B)
    return C
end
139+
140+
# Matrix-vector product where `A` is `transpose(adjoint(P))` of a complex CSC matrix `P`.
#
# Since transpose(adjoint(P)) == conj(P), the product is evaluated on the conjugated
# equation: `B` and `C` are conjugated in place, the kernel runs on `P` with conjugated
# `alpha`/`beta`, and both vectors are conjugated back afterwards. Returns `C`.
function LinearAlgebra.generic_matvecmul!(
        C::oneVector{T}, tA::AbstractChar,
        A::Transpose{T, <:Adjoint{T, <:oneSparseMatrixCSC{T}}},
        B::oneVector{T}, _add::MulAddMul
    ) where {T <: BlasComplex}
    tA = tA in ('S', 's', 'H', 'h') ? 'N' : tA
    # 'N' and 'T' are passed through; every other character is treated as 'C'.
    op = tA == 'N' ? 'N' : (tA == 'T' ? 'T' : 'C')
    P = A.parent.parent
    a = conj(_add.alpha)
    b = conj(_add.beta)
    @. B = conj(B)
    @. C = conj(C)
    sparse_gemv!(op, a, P, B, b, C)
    @. C = conj(C)
    @. B = conj(B)
    return C
end
158+
159+
# Matrix-vector product where `A` is `transpose(adjoint(P))` of a complex COO matrix `P`.
#
# Since transpose(adjoint(P)) == conj(P), the product is evaluated on the conjugated
# equation: `B` and `C` are conjugated in place, the kernel runs on `P` with conjugated
# `alpha`/`beta`, and both vectors are conjugated back afterwards. Returns `C`.
function LinearAlgebra.generic_matvecmul!(
        C::oneVector{T}, tA::AbstractChar,
        A::Transpose{T, <:Adjoint{T, <:oneSparseMatrixCOO{T}}},
        B::oneVector{T}, _add::MulAddMul
    ) where {T <: BlasComplex}
    tA = tA in ('S', 's', 'H', 'h') ? 'N' : tA
    # 'N' and 'T' are passed through; every other character is treated as 'C'.
    op = tA == 'N' ? 'N' : (tA == 'T' ? 'T' : 'C')
    P = A.parent.parent
    a = conj(_add.alpha)
    b = conj(_add.beta)
    @. B = conj(B)
    @. C = conj(C)
    sparse_gemv!(op, a, P, B, b, C)
    @. C = conj(C)
    @. B = conj(B)
    return C
end
177+
178+
# Custom * operators for Transpose{T, Adjoint{T, ...}} to ensure correct output size allocation
179+
# `transpose(adjoint(P)) * x` for a complex CSR matrix `P`: allocates the result with
# `size(A, 1)` entries and fills it through `generic_matvecmul!` with alpha = 1, beta = 0.
function Base.:*(
        A::Transpose{T, <:Adjoint{T, <:oneSparseMatrixCSR{T}}}, x::oneVector{T}
    ) where {T <: BlasComplex}
    y = similar(x, T, size(A, 1))
    scaling = LinearAlgebra.MulAddMul(one(T), zero(T))
    LinearAlgebra.generic_matvecmul!(y, 'N', A, x, scaling)
    return y
end
185+
186+
# `transpose(adjoint(P)) * x` for a complex CSC matrix `P`: allocates the result with
# `size(A, 1)` entries and fills it through `generic_matvecmul!` with alpha = 1, beta = 0.
function Base.:*(
        A::Transpose{T, <:Adjoint{T, <:oneSparseMatrixCSC{T}}}, x::oneVector{T}
    ) where {T <: BlasComplex}
    y = similar(x, T, size(A, 1))
    scaling = LinearAlgebra.MulAddMul(one(T), zero(T))
    LinearAlgebra.generic_matvecmul!(y, 'N', A, x, scaling)
    return y
end
192+
193+
# `transpose(adjoint(P)) * x` for a complex COO matrix `P`: allocates the result with
# `size(A, 1)` entries and fills it through `generic_matvecmul!` with alpha = 1, beta = 0.
function Base.:*(
        A::Transpose{T, <:Adjoint{T, <:oneSparseMatrixCOO{T}}}, x::oneVector{T}
    ) where {T <: BlasComplex}
    y = similar(x, T, size(A, 1))
    scaling = LinearAlgebra.MulAddMul(one(T), zero(T))
    LinearAlgebra.generic_matvecmul!(y, 'N', A, x, scaling)
    return y
end
199+
200+
# Matrix-matrix multiplication with transpose/adjoint
201+
# Sparse-times-dense product where `A` is a lazy `transpose` of a CSR matrix
# `P = A.parent`.
#
# Updates `C` with `alpha * op(A) * op(B) + beta * C` (`alpha`/`beta` come from `_add`).
# Combined with the wrapper, `tA` selects, in terms of `P`:
#   'N' -> transpose(P)           (kernel op 'T')
#   'T' -> P                      (kernel op 'N')
#   'C' -> adjoint(transpose(P)) == conj(P)
# Returns `C`.
function LinearAlgebra.generic_matmatmul!(
        C::oneMatrix{T}, tA, tB, A::Transpose{T, <:oneSparseMatrixCSR{T}},
        B::oneMatrix{T}, _add::MulAddMul
    ) where {T <: BlasFloat}
    tA = tA in ('S', 's', 'H', 'h') ? 'N' : tA
    tB = tB in ('S', 's', 'H', 'h') ? 'N' : tB
    P = A.parent
    if tA == 'N'
        sparse_gemm!('T', tB, _add.alpha, P, B, _add.beta, C)
    elseif tA == 'T' || T <: Real
        # For real element types conj(P) == P, so 'C' collapses to a plain operand too.
        sparse_gemm!('N', tB, _add.alpha, P, B, _add.beta, C)
    else
        # conj(P) has no transposition character (passing 'C' would use Pᴴ, which is a
        # different matrix). Evaluate the conjugated equation instead:
        # conj(C) = conj(alpha) * P * op(conj(B)) + conj(beta) * conj(C).
        @. B = conj(B)
        @. C = conj(C)
        try
            sparse_gemm!('N', tB, conj(_add.alpha), P, B, conj(_add.beta), C)
        finally
            # Undo the temporary conjugation even if the call throws, so the caller's
            # input `B` is never left modified.
            @. C = conj(C)
            @. B = conj(B)
        end
    end
    return C
end
207+
208+
# Sparse-times-dense product where `A` is a lazy `adjoint` of a CSR matrix `P = A.parent`.
#
# Updates `C` with `alpha * op(A) * op(B) + beta * C` (`alpha`/`beta` come from `_add`)
# and returns `C`. In terms of `P`: 'N' -> kernel op 'C', 'T' -> conj(P) (handled by
# conjugating the operands), any other character -> kernel op 'N'.
function LinearAlgebra.generic_matmatmul!(
        C::oneMatrix{T}, tA, tB, A::Adjoint{T, <:oneSparseMatrixCSR{T}},
        B::oneMatrix{T}, _add::MulAddMul
    ) where {T <: BlasFloat}
    plain = ('S', 's', 'H', 'h')
    opa = tA in plain ? 'N' : tA
    opb = tB in plain ? 'N' : tB
    P = A.parent
    if opa != 'T'
        sparse_gemm!(opa == 'N' ? 'C' : 'N', opb, _add.alpha, P, B, _add.beta, C)
        return C
    end
    # transpose(adjoint(P)) == conj(P): evaluate the conjugated equation on `P` and
    # conjugate `B` and `C` back afterwards.
    a = conj(_add.alpha)
    b = conj(_add.beta)
    @. B = conj(B)
    @. C = conj(C)
    sparse_gemm!('N', opb, a, P, B, b, C)
    @. C = conj(C)
    @. B = conj(B)
    return C
end
225+
226+
# Sparse-times-dense product where `A` is a lazy `transpose` of a CSC matrix
# `P = A.parent`.
#
# Updates `C` with `alpha * op(A) * op(B) + beta * C` (`alpha`/`beta` come from `_add`).
# Combined with the wrapper, `tA` selects, in terms of `P`:
#   'N' -> transpose(P)           (kernel op 'T')
#   'T' -> P                      (kernel op 'N')
#   'C' -> adjoint(transpose(P)) == conj(P)
# Returns `C`.
function LinearAlgebra.generic_matmatmul!(
        C::oneMatrix{T}, tA, tB, A::Transpose{T, <:oneSparseMatrixCSC{T}},
        B::oneMatrix{T}, _add::MulAddMul
    ) where {T <: BlasFloat}
    tA = tA in ('S', 's', 'H', 'h') ? 'N' : tA
    tB = tB in ('S', 's', 'H', 'h') ? 'N' : tB
    P = A.parent
    if tA == 'N'
        sparse_gemm!('T', tB, _add.alpha, P, B, _add.beta, C)
    elseif tA == 'T' || T <: Real
        # For real element types conj(P) == P, so 'C' collapses to a plain operand too.
        sparse_gemm!('N', tB, _add.alpha, P, B, _add.beta, C)
    else
        # conj(P) has no transposition character (passing 'C' would use Pᴴ, which is a
        # different matrix). Evaluate the conjugated equation instead:
        # conj(C) = conj(alpha) * P * op(conj(B)) + conj(beta) * conj(C).
        @. B = conj(B)
        @. C = conj(C)
        try
            sparse_gemm!('N', tB, conj(_add.alpha), P, B, conj(_add.beta), C)
        finally
            # Undo the temporary conjugation even if the call throws, so the caller's
            # input `B` is never left modified.
            @. C = conj(C)
            @. B = conj(B)
        end
    end
    return C
end
232+
233+
# Sparse-times-dense product where `A` is a lazy `adjoint` of a CSC matrix `P = A.parent`.
#
# Updates `C` with `alpha * op(A) * op(B) + beta * C` (`alpha`/`beta` come from `_add`)
# and returns `C`. In terms of `P`: 'N' -> kernel op 'C', 'T' -> conj(P) (handled by
# conjugating the operands), any other character -> kernel op 'N'.
function LinearAlgebra.generic_matmatmul!(
        C::oneMatrix{T}, tA, tB, A::Adjoint{T, <:oneSparseMatrixCSC{T}},
        B::oneMatrix{T}, _add::MulAddMul
    ) where {T <: BlasFloat}
    plain = ('S', 's', 'H', 'h')
    opa = tA in plain ? 'N' : tA
    opb = tB in plain ? 'N' : tB
    P = A.parent
    if opa != 'T'
        sparse_gemm!(opa == 'N' ? 'C' : 'N', opb, _add.alpha, P, B, _add.beta, C)
        return C
    end
    # transpose(adjoint(P)) == conj(P): evaluate the conjugated equation on `P` and
    # conjugate `B` and `C` back afterwards.
    a = conj(_add.alpha)
    b = conj(_add.beta)
    @. B = conj(B)
    @. C = conj(C)
    sparse_gemm!('N', opb, a, P, B, b, C)
    @. C = conj(C)
    @. B = conj(B)
    return C
end
250+
251+
# Sparse-times-dense product where `A` is a lazy `transpose` of a COO matrix
# `P = A.parent`.
#
# Updates `C` with `alpha * op(A) * op(B) + beta * C` (`alpha`/`beta` come from `_add`).
# Combined with the wrapper, `tA` selects, in terms of `P`:
#   'N' -> transpose(P)           (kernel op 'T')
#   'T' -> P                      (kernel op 'N')
#   'C' -> adjoint(transpose(P)) == conj(P)
# Returns `C`.
function LinearAlgebra.generic_matmatmul!(
        C::oneMatrix{T}, tA, tB, A::Transpose{T, <:oneSparseMatrixCOO{T}},
        B::oneMatrix{T}, _add::MulAddMul
    ) where {T <: BlasFloat}
    tA = tA in ('S', 's', 'H', 'h') ? 'N' : tA
    tB = tB in ('S', 's', 'H', 'h') ? 'N' : tB
    P = A.parent
    if tA == 'N'
        sparse_gemm!('T', tB, _add.alpha, P, B, _add.beta, C)
    elseif tA == 'T' || T <: Real
        # For real element types conj(P) == P, so 'C' collapses to a plain operand too.
        sparse_gemm!('N', tB, _add.alpha, P, B, _add.beta, C)
    else
        # conj(P) has no transposition character (passing 'C' would use Pᴴ, which is a
        # different matrix). Evaluate the conjugated equation instead:
        # conj(C) = conj(alpha) * P * op(conj(B)) + conj(beta) * conj(C).
        @. B = conj(B)
        @. C = conj(C)
        try
            sparse_gemm!('N', tB, conj(_add.alpha), P, B, conj(_add.beta), C)
        finally
            # Undo the temporary conjugation even if the call throws, so the caller's
            # input `B` is never left modified.
            @. C = conj(C)
            @. B = conj(B)
        end
    end
    return C
end
257+
258+
# Sparse-times-dense product where `A` is a lazy `adjoint` of a COO matrix `P = A.parent`.
#
# Updates `C` with `alpha * op(A) * op(B) + beta * C` (`alpha`/`beta` come from `_add`)
# and returns `C`. In terms of `P`: 'N' -> kernel op 'C', 'T' -> conj(P) (handled by
# conjugating the operands), any other character -> kernel op 'N'.
function LinearAlgebra.generic_matmatmul!(
        C::oneMatrix{T}, tA, tB, A::Adjoint{T, <:oneSparseMatrixCOO{T}},
        B::oneMatrix{T}, _add::MulAddMul
    ) where {T <: BlasFloat}
    plain = ('S', 's', 'H', 'h')
    opa = tA in plain ? 'N' : tA
    opb = tB in plain ? 'N' : tB
    P = A.parent
    if opa != 'T'
        sparse_gemm!(opa == 'N' ? 'C' : 'N', opb, _add.alpha, P, B, _add.beta, C)
        return C
    end
    # transpose(adjoint(P)) == conj(P): evaluate the conjugated equation on `P` and
    # conjugate `B` and `C` back afterwards.
    a = conj(_add.alpha)
    b = conj(_add.beta)
    @. B = conj(B)
    @. C = conj(C)
    sparse_gemm!('N', opb, a, P, B, b, C)
    @. C = conj(C)
    @. B = conj(B)
    return C
end

lib/mkl/utils.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,5 @@ end
113113
ptrs = pointer.(batch)
114114
return oneArray(ptrs)
115115
end
116-
117116
# Swap the transposition flag: 'N' becomes 'T'; every other character becomes 'N'.
# Note that 'C' also maps to 'N', so no conjugation is carried through this helper.
function flip_trans(trans::Char)
    if trans == 'N'
        return 'T'
    end
    return 'N'
end
118117
# Swap the triangle flag: 'L' becomes 'U'; every other character becomes 'L'.
function flip_uplo(uplo::Char)
    if uplo == 'L'
        return 'U'
    end
    return 'L'
end

lib/mkl/wrappers_sparse.jl

Lines changed: 46 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,47 @@ for (fname, elty, intty) in ((:onemklSsparse_set_csr_data , :Float32 , :Int3
1313
(:onemklZsparse_set_csr_data , :ComplexF64, :Int32),
1414
(:onemklZsparse_set_csr_data_64, :ComplexF64, :Int64))
1515
@eval begin
16+
17+
# Build a `oneSparseMatrixCSR` directly from device vectors that already hold the CSR
# index arrays (`rowPtr`, `colVal`) and stored values (`nzVal`); `dims` is `(m, n)`.
#
# The stored-entry count is `length(nzVal)`. When either dimension is zero, no oneMKL
# matrix handle is created and the object's handle slot is `nothing`.
function oneSparseMatrixCSR(
        rowPtr::oneVector{$intty}, colVal::oneVector{$intty},
        nzVal::oneVector{$elty}, dims::NTuple{2, Int}
    )
    m, n = dims
    nnzA = length(nzVal)
    # Empty matrix: nothing to attach, so do not create a handle at all.
    if m == 0 || n == 0
        return oneSparseMatrixCSR{$elty, $intty}(nothing, rowPtr, colVal, nzVal, (m, n), nnzA)
    end
    # Initialise the handle only on this path, where it is stored in the returned object
    # and released by its finalizer. Initialising it before the emptiness check would
    # drop an initialised handle with nothing left to release it.
    handle_ptr = Ref{matrix_handle_t}()
    onemklXsparse_init_matrix_handle(handle_ptr)
    queue = global_queue(context(nzVal), device(nzVal))
    $fname(sycl_queue(queue), handle_ptr[], m, n, 'O', rowPtr, colVal, nzVal)
    dA = oneSparseMatrixCSR{$elty, $intty}(handle_ptr[], rowPtr, colVal, nzVal, (m, n), nnzA)
    finalizer(sparse_release_matrix_handle, dA)
    return dA
end
36+
37+
# Build a `oneSparseMatrixCSC` directly from device vectors that already hold the CSC
# index arrays (`colPtr`, `rowVal`) and stored values (`nzVal`); `dims` is `(m, n)`.
#
# The stored-entry count is `length(nzVal)`. When either dimension is zero, no oneMKL
# matrix handle is created and the object's handle slot is `nothing`.
function oneSparseMatrixCSC(
        colPtr::oneVector{$intty}, rowVal::oneVector{$intty},
        nzVal::oneVector{$elty}, dims::NTuple{2, Int}
    )
    m, n = dims
    nnzA = length(nzVal)
    # Empty matrix: nothing to attach, so do not create a handle at all.
    if m == 0 || n == 0
        return oneSparseMatrixCSC{$elty, $intty}(nothing, colPtr, rowVal, nzVal, (m, n), nnzA)
    end
    # Initialise the handle only on this path, where it is stored in the returned object
    # and released by its finalizer. Initialising it before the emptiness check would
    # drop an initialised handle with nothing left to release it.
    queue = global_queue(context(nzVal), device(nzVal))
    handle_ptr = Ref{matrix_handle_t}()
    onemklXsparse_init_matrix_handle(handle_ptr)
    # CSC of A is CSR of Aᵀ, hence the swapped dimensions (n, m) in the handle.
    $fname(sycl_queue(queue), handle_ptr[], n, m, 'O', colPtr, rowVal, nzVal)
    dA = oneSparseMatrixCSC{$elty, $intty}(handle_ptr[], colPtr, rowVal, nzVal, (m, n), nnzA)
    finalizer(sparse_release_matrix_handle, dA)
    return dA
end
56+
1657
function oneSparseMatrixCSR(A::SparseMatrixCSC{$elty, $intty})
1758
handle_ptr = Ref{matrix_handle_t}()
1859
onemklXsparse_init_matrix_handle(handle_ptr)
@@ -140,8 +181,11 @@ for SparseMatrix in (:oneSparseMatrixCSC,)
140181
beta::Number,
141182
y::oneStridedVector{$elty})
142183

143-
queue = global_queue(context(x), device())
144-
$fname(sycl_queue(queue), flip_trans(trans), alpha, A.handle, x, beta, y)
184+
queue = global_queue(context(x), device(x))
185+
m, n = size(A)
186+
if m != 0 && n != 0
187+
$fname(sycl_queue(queue), flip_trans(trans), alpha, A.handle, x, beta, y)
188+
end
145189
y
146190
end
147191
end

test/onemkl.jl

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1090,6 +1090,10 @@ end
10901090
B = oneSparseMatrixCSR(A)
10911091
A2 = SparseMatrixCSC(B)
10921092
@test A == A2
1093+
C = oneSparseMatrixCSR(B.rowPtr, B.colVal, B.nzVal, size(B))
1094+
A3 = SparseMatrixCSC(C)
1095+
@test A == A3
1096+
D = oneSparseMatrixCSR(oneVector(S[]), oneVector(S[]), oneVector(T[]), (0, 0)) # empty matrix
10931097
end
10941098
end
10951099

@@ -1101,6 +1105,13 @@ end
11011105
B = oneSparseMatrixCSC(A)
11021106
A2 = SparseMatrixCSC(B)
11031107
@test A == A2
1108+
C = oneSparseMatrixCSC(A.colptr |> oneVector, A.rowval |> oneVector, A.nzval |> oneVector, size(A))
1111+
A3 = SparseMatrixCSC(C)
1112+
@test A == A3
1113+
D = oneSparseMatrixCSC(oneVector(S[]), oneVector(S[]), oneVector(T[]), (0, 0)) # empty matrix
11041115
end
11051116
end
11061117

@@ -1129,10 +1140,17 @@ end
11291140
beta = rand(T)
11301141
oneMKL.sparse_optimize_gemv!(transa, dA)
11311142
oneMKL.sparse_gemv!(transa, alpha, dA, dx, beta, dy)
1132-
@test alpha * opa(A) * x + beta * y ≈ collect(dy)
1143+
@test alpha * opa(A) * x + beta * y ≈ collect(dy)
1144+
dy = oneVector{T}(y)
1145+
@test alpha * opa(A) * x + beta * y ≈ Array(alpha * opa(dA) * dx + beta * dy)
1146+
tx = transa == 'N' ? rand(T, 20) : rand(T, 10)
1147+
ty = transa == 'N' ? rand(T, 10) : rand(T, 20)
1148+
dtx = oneVector{T}(tx)
1149+
dty = oneVector{T}(ty)
1150+
t = @test alpha * opa(A') * tx + beta * ty ≈ Array(alpha * opa(dA') * dtx + beta * dty)
11331151
end
1134-
end
11351152
end
1153+
end
11361154

11371155
@testset "sparse gemm" begin
11381156
@testset "$SparseMatrix" for SparseMatrix in (oneSparseMatrixCSR, oneSparseMatrixCSC)
@@ -1153,6 +1171,8 @@ end
11531171
oneMKL.sparse_gemm!(transa, transb, alpha, dA, dB, beta, dC)
11541172

11551173
@test alpha * opa(A) * opb(B) + beta * C ≈ collect(dC)
1174+
dC = oneMatrix{T}(C)
1175+
@test alpha * opa(A) * opb(B) + beta * C ≈ Array(alpha * opa(dA) * opb(dB) + beta * dC)
11561176
oneMKL.sparse_optimize_gemm!(transa, transb, 2, dA)
11571177
end
11581178
end

0 commit comments

Comments
 (0)