You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
Copy file name to clipboardExpand all lines: lib/mkl/interfaces.jl
+245-4Lines changed: 245 additions & 4 deletions
Original file line number
Diff line number
Diff line change
@@ -7,8 +7,13 @@ function LinearAlgebra.generic_matvecmul!(C::oneVector{T}, tA::AbstractChar, A::
7
7
sparse_gemv!(tA, _add.alpha, A, B, _add.beta, C)
8
8
end
9
9
10
-
function LinearAlgebra.generic_matvecmul!(C::oneVector{T}, tA::AbstractChar, A::oneSparseMatrixCSC{T}, B::oneVector{T}, _add::MulAddMul) where T <:BlasReal
11
-
tA = tA in ('S', 's', 'H', 'h') ?'T':flip_trans(tA)
10
+
function LinearAlgebra.generic_matvecmul!(C::oneVector{T}, tA::AbstractChar, A::oneSparseMatrixCSC{T}, B::oneVector{T}, _add::MulAddMul) where T <:BlasFloat
11
+
tA = tA in ('S', 's', 'H', 'h') ?'N': tA
12
+
sparse_gemv!(tA, _add.alpha, A, B, _add.beta, C)
13
+
end
14
+
15
+
function LinearAlgebra.generic_matvecmul!(C::oneVector{T}, tA::AbstractChar, A::oneSparseMatrixCOO{T}, B::oneVector{T}, _add::MulAddMul) where T <:BlasFloat
16
+
tA = tA in ('S', 's', 'H', 'h') ?'N': tA
12
17
sparse_gemv!(tA, _add.alpha, A, B, _add.beta, C)
13
18
end
14
19
@@ -18,8 +23,14 @@ function LinearAlgebra.generic_matmatmul!(C::oneMatrix{T}, tA, tB, A::oneSparseM
18
23
sparse_gemm!(tA, tB, _add.alpha, A, B, _add.beta, C)
19
24
end
20
25
21
-
function LinearAlgebra.generic_matmatmul!(C::oneMatrix{T}, tA, tB, A::oneSparseMatrixCSC{T}, B::oneMatrix{T}, _add::MulAddMul) where T <:BlasReal
22
-
tA = tA in ('S', 's', 'H', 'h') ?'T':flip_trans(tA)
26
+
function LinearAlgebra.generic_matmatmul!(C::oneMatrix{T}, tA, tB, A::oneSparseMatrixCSC{T}, B::oneMatrix{T}, _add::MulAddMul) where T <:BlasFloat
27
+
tA = tA in ('S', 's', 'H', 'h') ?'N': tA
28
+
tB = tB in ('S', 's', 'H', 'h') ?'N': tB
29
+
sparse_gemm!(tA, tB, _add.alpha, A, B, _add.beta, C)
30
+
end
31
+
32
+
function LinearAlgebra.generic_matmatmul!(C::oneMatrix{T}, tA, tB, A::oneSparseMatrixCOO{T}, B::oneMatrix{T}, _add::MulAddMul) where T <:BlasFloat
33
+
tA = tA in ('S', 's', 'H', 'h') ?'N': tA
23
34
tB = tB in ('S', 's', 'H', 'h') ?'N': tB
24
35
sparse_gemm!(tA, tB, _add.alpha, A, B, _add.beta, C)
25
36
end
@@ -31,3 +42,233 @@ end
31
42
function LinearAlgebra.generic_trimatdiv!(C::oneMatrix{T}, uploc, isunitc, tfun::Function, A::oneSparseMatrixCSR{T}, B::oneMatrix{T}) where T <:BlasFloat
32
43
sparse_trsm!(uploc, tfun === identity ?'N': tfun === transpose ?'T':'C', 'N', isunitc, one(T), A, B, C)
33
44
end
45
+
46
+
# Handle Transpose and Adjoint wrappers for sparse matrices
47
+
# Let the low-level wrappers handle the CSC->CSR conversion and flip_trans logic
48
+
49
+
# Matrix-vector multiplication with transpose/adjoint
50
+
function LinearAlgebra.generic_matvecmul!(C::oneVector{T}, tA::AbstractChar, A::Transpose{T, <:oneSparseMatrixCSR{T}}, B::oneVector{T}, _add::MulAddMul) where T <:BlasFloat
51
+
tA = tA in ('S', 's', 'H', 'h') ?'N': tA
52
+
tA_final = tA =='N'?'T': (tA =='T'?'N':'C')
53
+
sparse_gemv!(tA_final, _add.alpha, A.parent, B, _add.beta, C)
54
+
end
55
+
56
+
function LinearAlgebra.generic_matvecmul!(C::oneVector{T}, tA::AbstractChar, A::Adjoint{T, <:oneSparseMatrixCSR{T}}, B::oneVector{T}, _add::MulAddMul) where T <:BlasFloat
57
+
tA = tA in ('S', 's', 'H', 'h') ?'N': tA
58
+
if tA =='T'
59
+
alpha = _add.alpha
60
+
beta = _add.beta
61
+
B .=conj.(B)
62
+
C .=conj.(C)
63
+
sparse_gemv!('N', conj(alpha), A.parent, B, conj(beta), C)
64
+
C .=conj.(C)
65
+
B .=conj.(B)
66
+
else
67
+
tA_final = tA =='N'?'C':'N'
68
+
sparse_gemv!(tA_final, _add.alpha, A.parent, B, _add.beta, C)
69
+
end
70
+
return C
71
+
end
72
+
73
+
function LinearAlgebra.generic_matvecmul!(C::oneVector{T}, tA::AbstractChar, A::Transpose{T, <:oneSparseMatrixCSC{T}}, B::oneVector{T}, _add::MulAddMul) where T <:BlasFloat
74
+
tA = tA in ('S', 's', 'H', 'h') ?'N': tA
75
+
tA_final = tA =='N'?'T': (tA =='T'?'N':'C')
76
+
sparse_gemv!(tA_final, _add.alpha, A.parent, B, _add.beta, C)
77
+
end
78
+
79
+
function LinearAlgebra.generic_matvecmul!(C::oneVector{T}, tA::AbstractChar, A::Adjoint{T, <:oneSparseMatrixCSC{T}}, B::oneVector{T}, _add::MulAddMul) where T <:BlasFloat
80
+
tA = tA in ('S', 's', 'H', 'h') ?'N': tA
81
+
if tA =='T'
82
+
alpha = _add.alpha
83
+
beta = _add.beta
84
+
B .=conj.(B)
85
+
C .=conj.(C)
86
+
sparse_gemv!('N', conj(alpha), A.parent, B, conj(beta), C)
87
+
C .=conj.(C)
88
+
B .=conj.(B)
89
+
else
90
+
tA_final = tA =='N'?'C':'N'
91
+
sparse_gemv!(tA_final, _add.alpha, A.parent, B, _add.beta, C)
92
+
end
93
+
return C
94
+
end
95
+
96
+
function LinearAlgebra.generic_matvecmul!(C::oneVector{T}, tA::AbstractChar, A::Transpose{T, <:oneSparseMatrixCOO{T}}, B::oneVector{T}, _add::MulAddMul) where T <:BlasFloat
97
+
tA = tA in ('S', 's', 'H', 'h') ?'N': tA
98
+
tA_final = tA =='N'?'T': (tA =='T'?'N':'C')
99
+
sparse_gemv!(tA_final, _add.alpha, A.parent, B, _add.beta, C)
100
+
end
101
+
102
+
function LinearAlgebra.generic_matvecmul!(C::oneVector{T}, tA::AbstractChar, A::Adjoint{T, <:oneSparseMatrixCOO{T}}, B::oneVector{T}, _add::MulAddMul) where T <:BlasFloat
103
+
tA = tA in ('S', 's', 'H', 'h') ?'N': tA
104
+
if tA =='T'
105
+
alpha = _add.alpha
106
+
beta = _add.beta
107
+
B .=conj.(B)
108
+
C .=conj.(C)
109
+
sparse_gemv!('N', conj(alpha), A.parent, B, conj(beta), C)
110
+
C .=conj.(C)
111
+
B .=conj.(B)
112
+
else
113
+
tA_final = tA =='N'?'C':'N'
114
+
sparse_gemv!(tA_final, _add.alpha, A.parent, B, _add.beta, C)
115
+
end
116
+
return C
117
+
end
118
+
119
+
# Handle Transpose{T, Adjoint{T, ...}} for complex matrices
120
+
# transpose(adjoint(A)) for complex matrices needs special handling
121
+
function LinearAlgebra.generic_matvecmul!(C::oneVector{T}, tA::AbstractChar, A::Transpose{T, <:Adjoint{T, <:oneSparseMatrixCSR{T}}}, B::oneVector{T}, _add::MulAddMul) where T <:BlasComplex
122
+
tA = tA in ('S', 's', 'H', 'h') ?'N': tA
123
+
# transpose(adjoint(A)) = conj(A), so we need to conjugate
124
+
alpha = _add.alpha
125
+
beta = _add.beta
126
+
B .=conj.(B)
127
+
C .=conj.(C)
128
+
if tA =='N'
129
+
sparse_gemv!('N', conj(alpha), A.parent.parent, B, conj(beta), C)
130
+
elseif tA =='T'
131
+
sparse_gemv!('T', conj(alpha), A.parent.parent, B, conj(beta), C)
132
+
else# tA == 'C'
133
+
sparse_gemv!('C', conj(alpha), A.parent.parent, B, conj(beta), C)
134
+
end
135
+
C .=conj.(C)
136
+
B .=conj.(B)
137
+
return C
138
+
end
139
+
140
+
function LinearAlgebra.generic_matvecmul!(C::oneVector{T}, tA::AbstractChar, A::Transpose{T, <:Adjoint{T, <:oneSparseMatrixCSC{T}}}, B::oneVector{T}, _add::MulAddMul) where T <:BlasComplex
141
+
tA = tA in ('S', 's', 'H', 'h') ?'N': tA
142
+
# transpose(adjoint(A)) = conj(A), so we need to conjugate
143
+
alpha = _add.alpha
144
+
beta = _add.beta
145
+
B .=conj.(B)
146
+
C .=conj.(C)
147
+
if tA =='N'
148
+
sparse_gemv!('N', conj(alpha), A.parent.parent, B, conj(beta), C)
149
+
elseif tA =='T'
150
+
sparse_gemv!('T', conj(alpha), A.parent.parent, B, conj(beta), C)
151
+
else# tA == 'C'
152
+
sparse_gemv!('C', conj(alpha), A.parent.parent, B, conj(beta), C)
153
+
end
154
+
C .=conj.(C)
155
+
B .=conj.(B)
156
+
return C
157
+
end
158
+
159
+
function LinearAlgebra.generic_matvecmul!(C::oneVector{T}, tA::AbstractChar, A::Transpose{T, <:Adjoint{T, <:oneSparseMatrixCOO{T}}}, B::oneVector{T}, _add::MulAddMul) where T <:BlasComplex
160
+
tA = tA in ('S', 's', 'H', 'h') ?'N': tA
161
+
# transpose(adjoint(A)) = conj(A), so we need to conjugate
162
+
alpha = _add.alpha
163
+
beta = _add.beta
164
+
B .=conj.(B)
165
+
C .=conj.(C)
166
+
if tA =='N'
167
+
sparse_gemv!('N', conj(alpha), A.parent.parent, B, conj(beta), C)
168
+
elseif tA =='T'
169
+
sparse_gemv!('T', conj(alpha), A.parent.parent, B, conj(beta), C)
170
+
else# tA == 'C'
171
+
sparse_gemv!('C', conj(alpha), A.parent.parent, B, conj(beta), C)
172
+
end
173
+
C .=conj.(C)
174
+
B .=conj.(B)
175
+
return C
176
+
end
177
+
178
+
# Custom * operators for Transpose{T, Adjoint{T, ...}} to ensure correct output size allocation
179
+
function Base.:*(A::Transpose{T, <:Adjoint{T, <:oneSparseMatrixCSR{T}}}, x::oneVector{T}) where T <:BlasComplex
180
+
m, n =size(A)
181
+
y =similar(x, T, m)
182
+
LinearAlgebra.generic_matvecmul!(y, 'N', A, x, LinearAlgebra.MulAddMul(one(T), zero(T)))
183
+
return y
184
+
end
185
+
186
+
function Base.:*(A::Transpose{T, <:Adjoint{T, <:oneSparseMatrixCSC{T}}}, x::oneVector{T}) where T <:BlasComplex
187
+
m, n =size(A)
188
+
y =similar(x, T, m)
189
+
LinearAlgebra.generic_matvecmul!(y, 'N', A, x, LinearAlgebra.MulAddMul(one(T), zero(T)))
190
+
return y
191
+
end
192
+
193
+
function Base.:*(A::Transpose{T, <:Adjoint{T, <:oneSparseMatrixCOO{T}}}, x::oneVector{T}) where T <:BlasComplex
194
+
m, n =size(A)
195
+
y =similar(x, T, m)
196
+
LinearAlgebra.generic_matvecmul!(y, 'N', A, x, LinearAlgebra.MulAddMul(one(T), zero(T)))
197
+
return y
198
+
end
199
+
200
+
# Matrix-matrix multiplication with transpose/adjoint
201
+
function LinearAlgebra.generic_matmatmul!(C::oneMatrix{T}, tA, tB, A::Transpose{T, <:oneSparseMatrixCSR{T}}, B::oneMatrix{T}, _add::MulAddMul) where T <:BlasFloat
202
+
tA = tA in ('S', 's', 'H', 'h') ?'N': tA
203
+
tB = tB in ('S', 's', 'H', 'h') ?'N': tB
204
+
tA_final = tA =='N'?'T': (tA =='T'?'N':'C')
205
+
sparse_gemm!(tA_final, tB, _add.alpha, A.parent, B, _add.beta, C)
206
+
end
207
+
208
+
function LinearAlgebra.generic_matmatmul!(C::oneMatrix{T}, tA, tB, A::Adjoint{T, <:oneSparseMatrixCSR{T}}, B::oneMatrix{T}, _add::MulAddMul) where T <:BlasFloat
209
+
tA = tA in ('S', 's', 'H', 'h') ?'N': tA
210
+
tB = tB in ('S', 's', 'H', 'h') ?'N': tB
211
+
if tA =='T'
212
+
alpha = _add.alpha
213
+
beta = _add.beta
214
+
B .=conj.(B)
215
+
C .=conj.(C)
216
+
sparse_gemm!('N', tB, conj(alpha), A.parent, B, conj(beta), C)
217
+
C .=conj.(C)
218
+
B .=conj.(B)
219
+
else
220
+
tA_final = tA =='N'?'C':'N'
221
+
sparse_gemm!(tA_final, tB, _add.alpha, A.parent, B, _add.beta, C)
222
+
end
223
+
return C
224
+
end
225
+
226
+
function LinearAlgebra.generic_matmatmul!(C::oneMatrix{T}, tA, tB, A::Transpose{T, <:oneSparseMatrixCSC{T}}, B::oneMatrix{T}, _add::MulAddMul) where T <:BlasFloat
227
+
tA = tA in ('S', 's', 'H', 'h') ?'N': tA
228
+
tB = tB in ('S', 's', 'H', 'h') ?'N': tB
229
+
tA_final = tA =='N'?'T': (tA =='T'?'N':'C')
230
+
sparse_gemm!(tA_final, tB, _add.alpha, A.parent, B, _add.beta, C)
231
+
end
232
+
233
+
function LinearAlgebra.generic_matmatmul!(C::oneMatrix{T}, tA, tB, A::Adjoint{T, <:oneSparseMatrixCSC{T}}, B::oneMatrix{T}, _add::MulAddMul) where T <:BlasFloat
234
+
tA = tA in ('S', 's', 'H', 'h') ?'N': tA
235
+
tB = tB in ('S', 's', 'H', 'h') ?'N': tB
236
+
if tA =='T'
237
+
alpha = _add.alpha
238
+
beta = _add.beta
239
+
B .=conj.(B)
240
+
C .=conj.(C)
241
+
sparse_gemm!('N', tB, conj(alpha), A.parent, B, conj(beta), C)
242
+
C .=conj.(C)
243
+
B .=conj.(B)
244
+
else
245
+
tA_final = tA =='N'?'C':'N'
246
+
sparse_gemm!(tA_final, tB, _add.alpha, A.parent, B, _add.beta, C)
247
+
end
248
+
return C
249
+
end
250
+
251
+
function LinearAlgebra.generic_matmatmul!(C::oneMatrix{T}, tA, tB, A::Transpose{T, <:oneSparseMatrixCOO{T}}, B::oneMatrix{T}, _add::MulAddMul) where T <:BlasFloat
252
+
tA = tA in ('S', 's', 'H', 'h') ?'N': tA
253
+
tB = tB in ('S', 's', 'H', 'h') ?'N': tB
254
+
tA_final = tA =='N'?'T': (tA =='T'?'N':'C')
255
+
sparse_gemm!(tA_final, tB, _add.alpha, A.parent, B, _add.beta, C)
256
+
end
257
+
258
+
function LinearAlgebra.generic_matmatmul!(C::oneMatrix{T}, tA, tB, A::Adjoint{T, <:oneSparseMatrixCOO{T}}, B::oneMatrix{T}, _add::MulAddMul) where T <:BlasFloat
259
+
tA = tA in ('S', 's', 'H', 'h') ?'N': tA
260
+
tB = tB in ('S', 's', 'H', 'h') ?'N': tB
261
+
if tA =='T'
262
+
alpha = _add.alpha
263
+
beta = _add.beta
264
+
B .=conj.(B)
265
+
C .=conj.(C)
266
+
sparse_gemm!('N', tB, conj(alpha), A.parent, B, conj(beta), C)
267
+
C .=conj.(C)
268
+
B .=conj.(B)
269
+
else
270
+
tA_final = tA =='N'?'C':'N'
271
+
sparse_gemm!(tA_final, tB, _add.alpha, A.parent, B, _add.beta, C)
0 commit comments