Commit d24de3b

Katharine Hyatt (kshyatt) authored and committed

Sparse GPU array and broadcasting support

1 parent 8a27677 · commit d24de3b

File tree

11 files changed: +1840, -2 lines

Project.toml

Lines changed: 2 additions & 0 deletions

@@ -13,6 +13,7 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
ScopedValues = "7e506255-f358-4e82-b7e4-beb19740aa63"
Serialization = "9e88b42a-f829-5b0c-bbe9-9e923198166b"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"

[weakdeps]

@@ -33,5 +34,6 @@ Random = "1"
Reexport = "1"
ScopedValues = "1"
Serialization = "1"
SparseArrays = "1"
Statistics = "1"
julia = "1.10"

README.md

Lines changed: 60 additions & 0 deletions

@@ -31,3 +31,63 @@ This package is the counterpart of Julia's `AbstractArray` interface, but for GPU
types: It provides functionality and tooling to speed up development of new GPU array types.
**This package is not intended for end users!** Instead, you should use one of the packages
that build on GPUArrays.jl, such as [CUDA.jl](https://github.com/JuliaGPU/CUDA.jl), [oneAPI.jl](https://github.com/JuliaGPU/oneAPI.jl), [AMDGPU.jl](https://github.com/JuliaGPU/AMDGPU.jl), or [Metal.jl](https://github.com/JuliaGPU/Metal.jl).

## Interface methods

To support a new GPU backend, you will need to implement various interface methods for your backend's array types.
Some (CPU-based) examples can be seen in the testing library `JLArrays` (located in the `lib` directory of this package).

### Dense array support

### Sparse array support (optional)

`GPUArrays.jl` provides **device-side** array types for `CSC`, `CSR`, `COO`, and `BSR` matrices, as well as sparse vectors.
It also provides abstract types for these layouts; by making your backend's host-side types concrete subtypes of them, you benefit from the
backend-agnostic wrappers. In particular, `GPUArrays.jl` provides out-of-the-box support for broadcasting and `mapreduce` over
GPU sparse arrays.

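As a rough illustration of what those wrappers buy a backend, here is a minimal sketch using the CPU-backed `JLArrays` reference backend (container names come from `lib/JLArrays`; the exact container returned by a broadcast is left to the wrappers):

```julia
using SparseArrays, JLArrays

# Build a host-side sparse matrix and wrap it in the JLArrays CSC container.
A = JLSparseMatrixCSC(sprand(Float32, 100, 100, 0.05))

# Broadcasting and mapreduce go through the backend-agnostic sparse wrappers,
# so neither should require backend-specific kernels.
B = A .* 2f0
s = mapreduce(abs, +, A)
```
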
For **host-side** types, your custom sparse types should implement the following (a minimal sketch for a hypothetical backend follows the list):

- `dense_array_type` - the corresponding dense array type. For example, for a `CuSparseVector` or `CuSparseMatrixCXX`, the `dense_array_type` is `CuArray`.
- `sparse_array_type` - the **untyped** sparse array type corresponding to a given parametrized type. A `CuSparseVector{Tv, Ti}` would have a `sparse_array_type` of `CuSparseVector` -- note the lack of type parameters!
- `csc_type(::Type{T})` - the compressed sparse column type for your backend. A `CuSparseMatrixCSR` would have a `csc_type` of `CuSparseMatrixCSC`.
- `csr_type(::Type{T})` - the compressed sparse row type for your backend. A `CuSparseMatrixCSC` would have a `csr_type` of `CuSparseMatrixCSR`.
- `coo_type(::Type{T})` - the coordinate sparse matrix type for your backend. A `CuSparseMatrixCSC` would have a `coo_type` of `CuSparseMatrixCOO`.

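For a hypothetical backend with a dense array type `MyArray` and sparse types `MySparseVector`, `MySparseMatrixCSC`, `MySparseMatrixCSR`, and `MySparseMatrixCOO` (illustrative names, not part of `GPUArrays.jl`), the methods might look like the following sketch, mirroring what `JLArrays` does:

```julia
import GPUArrays

# Assumes MyArray, MySparseVector, MySparseMatrixCSC, MySparseMatrixCSR, and
# MySparseMatrixCOO are already defined by the backend.

# Dense counterpart of each sparse container.
GPUArrays.dense_array_type(::Type{<:MySparseVector})    = MyArray
GPUArrays.dense_array_type(::Type{<:MySparseMatrixCSC}) = MyArray
GPUArrays.dense_array_type(::Type{<:MySparseMatrixCSR}) = MyArray

# Untyped sparse container corresponding to a parametrized type.
GPUArrays.sparse_array_type(::Type{<:MySparseVector})    = MySparseVector
GPUArrays.sparse_array_type(::Type{<:MySparseMatrixCSC}) = MySparseMatrixCSC
GPUArrays.sparse_array_type(::Type{<:MySparseMatrixCSR}) = MySparseMatrixCSR

# Format conversions between the storage layouts this backend supports.
GPUArrays.csc_type(::Type{<:MySparseMatrixCSR}) = MySparseMatrixCSC
GPUArrays.csr_type(::Type{<:MySparseMatrixCSC}) = MySparseMatrixCSR
GPUArrays.coo_type(::Type{<:MySparseMatrixCSC}) = MySparseMatrixCOO
```
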
To use `SparseArrays.findnz`, your host-side type **must** implement `sortperm`. This can be done with scalar indexing, but it will be very slow; a naive sketch follows.

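One deliberately naive way to meet that requirement, assuming a hypothetical dense host-side vector type `MyVector` (illustrative; a real backend would use a device-side sort rather than a round trip through the CPU):

```julia
# Copy to the host, compute the permutation there, and upload the result again.
# Correct but slow, since it moves the whole vector across the device boundary.
function Base.sortperm(v::MyVector)
    return MyVector(sortperm(Array(v)))
end
```
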
Additionally, you need to teach `GPUArrays.jl` how to translate your backend's specific types onto the device. `GPUArrays.jl` provides the device-side types:

- `GPUSparseDeviceVector`
- `GPUSparseDeviceMatrixCSC`
- `GPUSparseDeviceMatrixCSR`
- `GPUSparseDeviceMatrixBSR`
- `GPUSparseDeviceMatrixCOO`

You will need to create a method of `Adapt.adapt_structure` for each format your backend supports. **Note** that if your backend supports separate address spaces, as CUDA and ROCm do, you need to provide a parameter to these device-side arrays to indicate in which address space the underlying pointers live. An example of adapting an array to the device-side struct:

```julia
function GPUArrays.GPUSparseDeviceVector(iPtr::MyDeviceVector{Ti, A},
                                         nzVal::MyDeviceVector{Tv, A},
                                         len::Int,
                                         nnz::Ti) where {Ti, Tv, A}
    GPUArrays.GPUSparseDeviceVector{Tv, Ti, MyDeviceVector{Ti, A}, MyDeviceVector{Tv, A}, A}(iPtr, nzVal, len, nnz)
end

function Adapt.adapt_structure(to::MyAdaptor, x::MySparseVector)
    return GPUArrays.GPUSparseDeviceVector(
        adapt(to, x.iPtr),
        adapt(to, x.nzVal),
        length(x), x.nnz
    )
end
```

You'll also need to inform `GPUArrays.jl` and `GPUCompiler.jl` which backend your sparse arrays belong to by extending `KernelAbstractions.jl`'s `get_backend()`:

```julia
KA.get_backend(::MySparseVector) = MyBackend()
```

lib/JLArrays/Project.toml

Lines changed: 4 additions & 0 deletions

@@ -7,11 +7,15 @@ version = "0.3.0"
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7"
KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"

[compat]
Adapt = "2.0, 3.0, 4.0"
GPUArrays = "11.1"
KernelAbstractions = "0.9, 0.10"
LinearAlgebra = "1"
Random = "1"
SparseArrays = "1"
julia = "1.8"

lib/JLArrays/src/JLArrays.jl

Lines changed: 224 additions & 1 deletion

@@ -6,11 +6,14 @@

module JLArrays

-export JLArray, JLVector, JLMatrix, jl, JLBackend
+export JLArray, JLVector, JLMatrix, jl, JLBackend, JLSparseVector, JLSparseMatrixCSC, JLSparseMatrixCSR

using GPUArrays

using Adapt
using SparseArrays, LinearAlgebra

import GPUArrays: dense_array_type

import KernelAbstractions
import KernelAbstractions: Adapt, StaticArrays, Backend, Kernel, StaticSize, DynamicSize, partition, blocks, workitems, launch_config

@@ -19,6 +22,11 @@ import KernelAbstractions: Adapt, StaticArrays, Backend, Kernel, StaticSize, Dyn
    import KernelAbstractions: POCL
end

module AS

const Generic = 0

end

#
# Device functionality

@@ -115,7 +123,141 @@ mutable struct JLArray{T, N} <: AbstractGPUArray{T, N}
    end
end

mutable struct JLSparseVector{Tv, Ti} <: GPUArrays.AbstractGPUSparseVector{Tv, Ti}
    iPtr::JLArray{Ti, 1}
    nzVal::JLArray{Tv, 1}
    len::Int
    nnz::Ti

    function JLSparseVector{Tv, Ti}(iPtr::JLArray{<:Integer, 1}, nzVal::JLArray{Tv, 1},
                                    len::Integer) where {Tv, Ti <: Integer}
        new{Tv, Ti}(iPtr, nzVal, len, length(nzVal))
    end
end
SparseArrays.nnz(x::JLSparseVector) = x.nnz
SparseArrays.nonzeroinds(x::JLSparseVector) = x.iPtr
SparseArrays.nonzeros(x::JLSparseVector) = x.nzVal

mutable struct JLSparseMatrixCSC{Tv, Ti} <: GPUArrays.AbstractGPUSparseMatrixCSC{Tv, Ti}
    colPtr::JLArray{Ti, 1}
    rowVal::JLArray{Ti, 1}
    nzVal::JLArray{Tv, 1}
    dims::NTuple{2,Int}
    nnz::Ti

    function JLSparseMatrixCSC{Tv, Ti}(colPtr::JLArray{<:Integer, 1}, rowVal::JLArray{<:Integer, 1},
                                       nzVal::JLArray{Tv, 1}, dims::NTuple{2,<:Integer}) where {Tv, Ti <: Integer}
        new{Tv, Ti}(colPtr, rowVal, nzVal, dims, length(nzVal))
    end
end
function JLSparseMatrixCSC(colPtr::JLArray{Ti, 1}, rowVal::JLArray{Ti, 1}, nzVal::JLArray{Tv, 1}, dims::NTuple{2,<:Integer}) where {Tv, Ti <: Integer}
    return JLSparseMatrixCSC{Tv, Ti}(colPtr, rowVal, nzVal, dims)
end
SparseArrays.SparseMatrixCSC(x::JLSparseMatrixCSC) = SparseMatrixCSC(size(x)..., Array(x.colPtr), Array(x.rowVal), Array(x.nzVal))

JLSparseMatrixCSC(A::JLSparseMatrixCSC) = A

function Base.getindex(A::JLSparseMatrixCSC{Tv, Ti}, i::Integer, j::Integer) where {Tv, Ti}
    @boundscheck checkbounds(A, i, j)
    r1 = Int(@inbounds A.colPtr[j])
    r2 = Int(@inbounds A.colPtr[j+1]-1)
    (r1 > r2) && return zero(Tv)
    r1 = searchsortedfirst(view(A.rowVal, r1:r2), i) + r1 - 1
    ((r1 > r2) || (A.rowVal[r1] != i)) ? zero(Tv) : A.nzVal[r1]
end

mutable struct JLSparseMatrixCSR{Tv, Ti} <: GPUArrays.AbstractGPUSparseMatrixCSR{Tv, Ti}
    rowPtr::JLArray{Ti, 1}
    colVal::JLArray{Ti, 1}
    nzVal::JLArray{Tv, 1}
    dims::NTuple{2,Int}
    nnz::Ti

    function JLSparseMatrixCSR{Tv, Ti}(rowPtr::JLArray{<:Integer, 1}, colVal::JLArray{<:Integer, 1},
                                       nzVal::JLArray{Tv, 1}, dims::NTuple{2,<:Integer}) where {Tv, Ti<:Integer}
        new{Tv, Ti}(rowPtr, colVal, nzVal, dims, length(nzVal))
    end
end
function JLSparseMatrixCSR(rowPtr::JLArray{Ti, 1}, colVal::JLArray{Ti, 1}, nzVal::JLArray{Tv, 1}, dims::NTuple{2,<:Integer}) where {Tv, Ti <: Integer}
    return JLSparseMatrixCSR{Tv, Ti}(rowPtr, colVal, nzVal, dims)
end
function SparseArrays.SparseMatrixCSC(x::JLSparseMatrixCSR)
    x_transpose = SparseMatrixCSC(size(x, 2), size(x, 1), Array(x.rowPtr), Array(x.colVal), Array(x.nzVal))
    return SparseMatrixCSC(transpose(x_transpose))
end

JLSparseMatrixCSC(Mat::Union{Transpose{Tv, <:SparseMatrixCSC}, Adjoint{Tv, <:SparseMatrixCSC}}) where {Tv} = JLSparseMatrixCSC(JLSparseMatrixCSR(Mat))

function Base.size(g::JLSparseMatrixCSR, d::Integer)
    if 1 <= d <= 2
        return g.dims[d]
    elseif d > 1
        return 1
    else
        throw(ArgumentError("dimension must be ≥ 1, got $d"))
    end
end

JLSparseMatrixCSR(Mat::Transpose{Tv, <:SparseMatrixCSC}) where {Tv} =
    JLSparseMatrixCSR(JLVector{Cint}(parent(Mat).colptr), JLVector{Cint}(parent(Mat).rowval),
                      JLVector(parent(Mat).nzval), size(Mat))
JLSparseMatrixCSR(Mat::Adjoint{Tv, <:SparseMatrixCSC}) where {Tv} =
    JLSparseMatrixCSR(JLVector{Cint}(parent(Mat).colptr), JLVector{Cint}(parent(Mat).rowval),
                      JLVector(conj.(parent(Mat).nzval)), size(Mat))

JLSparseMatrixCSR(A::JLSparseMatrixCSR) = A

function Base.getindex(A::JLSparseMatrixCSR{Tv, Ti}, i0::Integer, i1::Integer) where {Tv, Ti}
    @boundscheck checkbounds(A, i0, i1)
    c1 = Int(A.rowPtr[i0])
    c2 = Int(A.rowPtr[i0+1]-1)
    (c1 > c2) && return zero(Tv)
    c1 = searchsortedfirst(A.colVal, i1, c1, c2, Base.Order.Forward)
    (c1 > c2 || A.colVal[c1] != i1) && return zero(Tv)
    nonzeros(A)[c1]
end

GPUArrays.storage(a::JLArray) = a.data
GPUArrays.dense_array_type(a::JLArray{T, N}) where {T, N} = JLArray{T, N}
GPUArrays.dense_array_type(::Type{JLArray{T, N}}) where {T, N} = JLArray{T, N}
GPUArrays.dense_vector_type(a::JLArray{T, N}) where {T, N} = JLArray{T, 1}
GPUArrays.dense_vector_type(::Type{JLArray{T, N}}) where {T, N} = JLArray{T, 1}

GPUArrays.sparse_array_type(sa::JLSparseMatrixCSC) = JLSparseMatrixCSC
GPUArrays.sparse_array_type(::Type{<:JLSparseMatrixCSC}) = JLSparseMatrixCSC
GPUArrays.sparse_array_type(sa::JLSparseMatrixCSR) = JLSparseMatrixCSR
GPUArrays.sparse_array_type(::Type{<:JLSparseMatrixCSR}) = JLSparseMatrixCSR
GPUArrays.sparse_array_type(sa::JLSparseVector) = JLSparseVector
GPUArrays.sparse_array_type(::Type{<:JLSparseVector}) = JLSparseVector

GPUArrays.dense_array_type(sa::JLSparseVector) = JLArray
GPUArrays.dense_array_type(::Type{<:JLSparseVector}) = JLArray
GPUArrays.dense_array_type(sa::JLSparseMatrixCSC) = JLArray
GPUArrays.dense_array_type(::Type{<:JLSparseMatrixCSC}) = JLArray
GPUArrays.dense_array_type(sa::JLSparseMatrixCSR) = JLArray
GPUArrays.dense_array_type(::Type{<:JLSparseMatrixCSR}) = JLArray

GPUArrays.csc_type(sa::JLSparseMatrixCSR) = JLSparseMatrixCSC
GPUArrays.csr_type(sa::JLSparseMatrixCSC) = JLSparseMatrixCSR

Base.similar(Mat::JLSparseMatrixCSR) = JLSparseMatrixCSR(copy(Mat.rowPtr), copy(Mat.colVal), similar(nonzeros(Mat)), size(Mat))
Base.similar(Mat::JLSparseMatrixCSR, T::Type) = JLSparseMatrixCSR(copy(Mat.rowPtr), copy(Mat.colVal), similar(nonzeros(Mat), T), size(Mat))

Base.similar(Mat::JLSparseMatrixCSC, T::Type, N::Int, M::Int) = JLSparseMatrixCSC(JLVector([zero(Int32)]), JLVector{Int32}(undef, 0), JLVector{T}(undef, 0), (N, M))
Base.similar(Mat::JLSparseMatrixCSR, T::Type, N::Int, M::Int) = JLSparseMatrixCSR(JLVector([zero(Int32)]), JLVector{Int32}(undef, 0), JLVector{T}(undef, 0), (N, M))

Base.similar(Mat::JLSparseMatrixCSC{Tv, Ti}, N::Int, M::Int) where {Tv, Ti} = similar(Mat, Tv, N, M)
Base.similar(Mat::JLSparseMatrixCSR{Tv, Ti}, N::Int, M::Int) where {Tv, Ti} = similar(Mat, Tv, N, M)

Base.similar(Mat::JLSparseMatrixCSC, T::Type, dims::Tuple{Int, Int}) = similar(Mat, T, dims...)
Base.similar(Mat::JLSparseMatrixCSR, T::Type, dims::Tuple{Int, Int}) = similar(Mat, T, dims...)

Base.similar(Mat::JLSparseMatrixCSC, dims::Tuple{Int, Int}) = similar(Mat, dims...)
Base.similar(Mat::JLSparseMatrixCSR, dims::Tuple{Int, Int}) = similar(Mat, dims...)

JLArray(x::JLSparseVector) = JLArray(collect(SparseVector(x)))
JLArray(x::JLSparseMatrixCSC) = JLArray(collect(SparseMatrixCSC(x)))
JLArray(x::JLSparseMatrixCSR) = JLArray(collect(SparseMatrixCSC(x)))

# conversion of untyped data to a typed Array
function typed_data(x::JLArray{T}) where {T}
@@ -217,6 +359,79 @@ JLArray{T}(xs::AbstractArray{S,N}) where {T,N,S} = JLArray{T,N}(xs)
(::Type{JLArray{T,N} where T})(x::AbstractArray{S,N}) where {S,N} = JLArray{S,N}(x)
JLArray(A::AbstractArray{T,N}) where {T,N} = JLArray{T,N}(A)

function JLSparseVector(xs::SparseVector{Tv, Ti}) where {Ti, Tv}
    iPtr = JLVector{Ti}(undef, length(xs.nzind))
    nzVal = JLVector{Tv}(undef, length(xs.nzval))
    copyto!(iPtr, convert(Vector{Ti}, xs.nzind))
    copyto!(nzVal, convert(Vector{Tv}, xs.nzval))
    return JLSparseVector{Tv, Ti}(iPtr, nzVal, length(xs))
end
Base.length(x::JLSparseVector) = x.len
Base.size(x::JLSparseVector) = (x.len,)

function JLSparseMatrixCSC(xs::SparseMatrixCSC{Tv, Ti}) where {Ti, Tv}
    colPtr = JLVector{Ti}(undef, length(xs.colptr))
    rowVal = JLVector{Ti}(undef, length(xs.rowval))
    nzVal = JLVector{Tv}(undef, length(xs.nzval))
    copyto!(colPtr, convert(Vector{Ti}, xs.colptr))
    copyto!(rowVal, convert(Vector{Ti}, xs.rowval))
    copyto!(nzVal, convert(Vector{Tv}, xs.nzval))
    return JLSparseMatrixCSC{Tv, Ti}(colPtr, rowVal, nzVal, (xs.m, xs.n))
end
JLSparseMatrixCSC(xs::SparseVector) = JLSparseMatrixCSC(SparseMatrixCSC(xs))
Base.length(x::JLSparseMatrixCSC) = prod(x.dims)
Base.size(x::JLSparseMatrixCSC) = x.dims

function JLSparseMatrixCSR(xs::SparseMatrixCSC{Tv, Ti}) where {Ti, Tv}
    csr_xs = SparseMatrixCSC(transpose(xs))
    rowPtr = JLVector{Ti}(undef, length(csr_xs.colptr))
    colVal = JLVector{Ti}(undef, length(csr_xs.rowval))
    nzVal = JLVector{Tv}(undef, length(csr_xs.nzval))
    copyto!(rowPtr, convert(Vector{Ti}, csr_xs.colptr))
    copyto!(colVal, convert(Vector{Ti}, csr_xs.rowval))
    copyto!(nzVal, convert(Vector{Tv}, csr_xs.nzval))
    return JLSparseMatrixCSR{Tv, Ti}(rowPtr, colVal, nzVal, (xs.m, xs.n))
end
JLSparseMatrixCSR(xs::SparseVector{Tv, Ti}) where {Ti, Tv} = JLSparseMatrixCSR(SparseMatrixCSC(xs))
function JLSparseMatrixCSR(xs::JLSparseMatrixCSC{Tv, Ti}) where {Ti, Tv}
    return JLSparseMatrixCSR(SparseMatrixCSC(xs))
end
function JLSparseMatrixCSC(xs::JLSparseMatrixCSR{Tv, Ti}) where {Ti, Tv}
    return JLSparseMatrixCSC(SparseMatrixCSC(xs))
end
function Base.copyto!(dst::JLSparseMatrixCSR, src::JLSparseMatrixCSR)
    if size(dst) != size(src)
        throw(ArgumentError("Inconsistent Sparse Matrix size"))
    end
    resize!(dst.rowPtr, length(src.rowPtr))
    resize!(dst.colVal, length(src.colVal))
    resize!(SparseArrays.nonzeros(dst), length(SparseArrays.nonzeros(src)))
    copyto!(dst.rowPtr, src.rowPtr)
    copyto!(dst.colVal, src.colVal)
    copyto!(SparseArrays.nonzeros(dst), SparseArrays.nonzeros(src))
    dst.nnz = src.nnz
    dst
end
Base.length(x::JLSparseMatrixCSR) = prod(x.dims)
Base.size(x::JLSparseMatrixCSR) = x.dims

function GPUArrays._spadjoint(A::JLSparseMatrixCSR)
    Aᴴ = JLSparseMatrixCSC(A.rowPtr, A.colVal, conj(A.nzVal), reverse(size(A)))
    JLSparseMatrixCSR(Aᴴ)
end
function GPUArrays._sptranspose(A::JLSparseMatrixCSR)
    Aᵀ = JLSparseMatrixCSC(A.rowPtr, A.colVal, A.nzVal, reverse(size(A)))
    JLSparseMatrixCSR(Aᵀ)
end
function GPUArrays._spadjoint(A::JLSparseMatrixCSC)
    Aᴴ = JLSparseMatrixCSR(A.colPtr, A.rowVal, conj(A.nzVal), reverse(size(A)))
    JLSparseMatrixCSC(Aᴴ)
end
function GPUArrays._sptranspose(A::JLSparseMatrixCSC)
    Aᵀ = JLSparseMatrixCSR(A.colPtr, A.rowVal, A.nzVal, reverse(size(A)))
    JLSparseMatrixCSC(Aᵀ)
end

# idempotency
JLArray{T,N}(xs::JLArray{T,N}) where {T,N} = xs

@@ -358,9 +573,17 @@ function GPUArrays.mapreducedim!(f, op, R::AnyJLArray, A::Union{AbstractArray,Br
    R
end

Adapt.adapt_structure(to::Adaptor, x::JLSparseMatrixCSC{Tv,Ti}) where {Tv,Ti} =
    GPUSparseDeviceMatrixCSC{Tv,Ti,JLDeviceArray{Ti, 1}, JLDeviceArray{Tv, 1}, AS.Generic}(adapt(to, x.colPtr), adapt(to, x.rowVal), adapt(to, x.nzVal), x.dims, x.nnz)
Adapt.adapt_structure(to::Adaptor, x::JLSparseMatrixCSR{Tv,Ti}) where {Tv,Ti} =
    GPUSparseDeviceMatrixCSR{Tv,Ti,JLDeviceArray{Ti, 1}, JLDeviceArray{Tv, 1}, AS.Generic}(adapt(to, x.rowPtr), adapt(to, x.colVal), adapt(to, x.nzVal), x.dims, x.nnz)
Adapt.adapt_structure(to::Adaptor, x::JLSparseVector{Tv,Ti}) where {Tv,Ti} =
    GPUSparseDeviceVector{Tv,Ti,JLDeviceArray{Ti, 1}, JLDeviceArray{Tv, 1}, AS.Generic}(adapt(to, x.iPtr), adapt(to, x.nzVal), x.len, x.nnz)

## KernelAbstractions interface

KernelAbstractions.get_backend(a::JLA) where JLA <: JLArray = JLBackend()
KernelAbstractions.get_backend(a::JLA) where JLA <: Union{JLSparseMatrixCSC, JLSparseMatrixCSR, JLSparseVector} = JLBackend()

function KernelAbstractions.mkcontext(kernel::Kernel{JLBackend}, I, _ndrange, iterspace, ::Dynamic) where Dynamic
    return KernelAbstractions.CompilerMetadata{KernelAbstractions.ndrange(kernel), Dynamic}(I, _ndrange, iterspace)

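Taken together, the definitions above give the `JLArrays` reference backend host-side sparse containers that round-trip to and from `SparseArrays`. A small hypothetical exercise of those constructors (assuming the methods behave as shown in this diff):

```julia
using SparseArrays, JLArrays

A  = sprand(Float64, 8, 8, 0.25)   # host CSC matrix from SparseArrays
dA = JLSparseMatrixCSC(A)          # wrap it in the JLArrays CSC container
dR = JLSparseMatrixCSR(A)          # or build the CSR layout directly

dA[2, 3]                           # scalar getindex, via the CSC lookup defined above
B  = SparseMatrixCSC(dA)           # round-trip back to a SparseArrays matrix
D  = JLArray(dA)                   # densify into the backend's dense array type
```
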
src/GPUArrays.jl

Lines changed: 2 additions & 0 deletions

@@ -19,6 +19,7 @@ using KernelAbstractions

# device functionality
include("device/abstractarray.jl")
include("device/sparse.jl")

# host abstractions
include("host/abstractarray.jl")

@@ -34,6 +35,7 @@ include("host/random.jl")
include("host/quirks.jl")
include("host/uniformscaling.jl")
include("host/statistics.jl")
include("host/sparse.jl")
include("host/alloc_cache.jl")