Skip to content

Commit 1c30641

Browse files
committed
Fixes
1 parent b5f6994 commit 1c30641

File tree

4 files changed

+39
-20
lines changed

4 files changed

+39
-20
lines changed

lib/mkl/array.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ const oneAbstractSparseVector{Tv, Ti} = oneAbstractSparseArray{Tv, Ti, 1}
55
const oneAbstractSparseMatrix{Tv, Ti} = oneAbstractSparseArray{Tv, Ti, 2}
66

77
mutable struct oneSparseMatrixCSR{Tv, Ti} <: oneAbstractSparseMatrix{Tv, Ti}
8-
handle::matrix_handle_t
8+
handle::Union{Nothing, matrix_handle_t}
99
rowPtr::oneVector{Ti}
1010
colVal::oneVector{Ti}
1111
nzVal::oneVector{Tv}
@@ -14,7 +14,7 @@ mutable struct oneSparseMatrixCSR{Tv, Ti} <: oneAbstractSparseMatrix{Tv, Ti}
1414
end
1515

1616
mutable struct oneSparseMatrixCSC{Tv, Ti} <: oneAbstractSparseMatrix{Tv, Ti}
17-
handle::matrix_handle_t
17+
handle::Union{Nothing, matrix_handle_t}
1818
colPtr::oneVector{Ti}
1919
rowVal::oneVector{Ti}
2020
nzVal::oneVector{Tv}
@@ -23,7 +23,7 @@ mutable struct oneSparseMatrixCSC{Tv, Ti} <: oneAbstractSparseMatrix{Tv, Ti}
2323
end
2424

2525
mutable struct oneSparseMatrixCOO{Tv, Ti} <: oneAbstractSparseMatrix{Tv, Ti}
26-
handle::matrix_handle_t
26+
handle::Union{Nothing, matrix_handle_t}
2727
rowInd::oneVector{Ti}
2828
colInd::oneVector{Ti}
2929
nzVal::oneVector{Tv}

lib/mkl/wrappers_sparse.jl

Lines changed: 18 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
function sparse_release_matrix_handle(A::oneAbstractSparseMatrix)
2-
queue = global_queue(context(A.nzVal), device(A.nzVal))
3-
m, n = size(A)
4-
return if m != 0 && n != 0
2+
return if A.handle !== nothing
3+
queue = global_queue(context(A.nzVal), device(A.nzVal))
4+
oneL0.synchronize(queue)
55
handle_ptr = Ref{matrix_handle_t}(A.handle)
66
onemklXsparse_release_matrix_handle(sycl_queue(queue), handle_ptr)
77
end
@@ -29,9 +29,11 @@ for (fname, elty, intty) in ((:onemklSsparse_set_csr_data , :Float32 , :Int3
2929
# Don't update handle if matrix is empty
3030
if m != 0 && n != 0
3131
$fname(sycl_queue(queue), handle_ptr[], m, n, 'O', rowPtr, colVal, nzVal)
32+
dA = oneSparseMatrixCSR{$elty, $intty}(handle_ptr[], rowPtr, colVal, nzVal, (m, n), nnzA)
33+
finalizer(sparse_release_matrix_handle, dA)
34+
else
35+
dA = oneSparseMatrixCSR{$elty, $intty}(nothing, rowPtr, colVal, nzVal, (m, n), nnzA)
3236
end
33-
dA = oneSparseMatrixCSR{$elty, $intty}(handle_ptr[], rowPtr, colVal, nzVal, (m, n), nnzA)
34-
finalizer(sparse_release_matrix_handle, dA)
3537
return dA
3638
end
3739

@@ -47,9 +49,11 @@ for (fname, elty, intty) in ((:onemklSsparse_set_csr_data , :Float32 , :Int3
4749
# Don't update handle if matrix is empty
4850
if m != 0 && n != 0
4951
$fname(sycl_queue(queue), handle_ptr[], n, m, 'O', colPtr, rowVal, nzVal) # CSC of A is CSR of Aᵀ
52+
dA = oneSparseMatrixCSC{$elty, $intty}(handle_ptr[], colPtr, rowVal, nzVal, dims, nnzA)
53+
finalizer(sparse_release_matrix_handle, dA)
54+
else
55+
dA = oneSparseMatrixCSC{$elty, $intty}(nothing, colPtr, rowVal, nzVal, dims, nnzA)
5056
end
51-
dA = oneSparseMatrixCSC{$elty, $intty}(handle_ptr[], colPtr, rowVal, nzVal, dims, nnzA)
52-
finalizer(sparse_release_matrix_handle, dA)
5357
return dA
5458
end
5559

@@ -63,7 +67,6 @@ for (fname, elty, intty) in ((:onemklSsparse_set_csr_data , :Float32 , :Int3
6367
end
6468

6569
function SparseMatrixCSC(A::oneSparseMatrixCSR{$elty, $intty})
66-
handle_ptr = Ref{matrix_handle_t}()
6770
At = SparseMatrixCSC(reverse(A.dims)..., Vector(A.rowPtr), Vector(A.colVal), Vector(A.nzVal))
6871
A_csc = SparseMatrixCSC(At |> transpose)
6972
return A_csc
@@ -78,7 +81,6 @@ for (fname, elty, intty) in ((:onemklSsparse_set_csr_data , :Float32 , :Int3
7881
end
7982

8083
function SparseMatrixCSC(A::oneSparseMatrixCSC{$elty, $intty})
81-
handle_ptr = Ref{matrix_handle_t}()
8284
A_csc = SparseMatrixCSC(A.dims..., Vector(A.colPtr), Vector(A.rowVal), Vector(A.nzVal))
8385
return A_csc
8486
end
@@ -104,14 +106,17 @@ for (fname, elty, intty) in ((:onemklSsparse_set_coo_data , :Float32 , :Int3
104106
nzVal = oneVector{$elty}(val)
105107
nnzA = length(val)
106108
queue = global_queue(context(nzVal), device(nzVal))
107-
$fname(sycl_queue(queue), handle_ptr[], m, n, nnzA, 'O', rowInd, colInd, nzVal)
108-
dA = oneSparseMatrixCOO{$elty, $intty}(handle_ptr[], rowInd, colInd, nzVal, (m,n), nnzA)
109-
finalizer(sparse_release_matrix_handle, dA)
109+
if m != 0 && n != 0
110+
$fname(sycl_queue(queue), handle_ptr[], m, n, nnzA, 'O', rowInd, colInd, nzVal)
111+
dA = oneSparseMatrixCOO{$elty, $intty}(handle_ptr[], rowInd, colInd, nzVal, (m,n), nnzA)
112+
finalizer(sparse_release_matrix_handle, dA)
113+
else
114+
dA = oneSparseMatrixCOO{$elty, $intty}(nothing, rowInd, colInd, nzVal, (m,n), nnzA)
115+
end
110116
return dA
111117
end
112118

113119
function SparseMatrixCSC(A::oneSparseMatrixCOO{$elty, $intty})
114-
handle_ptr = Ref{matrix_handle_t}()
115120
A = sparse(Vector(A.rowInd), Vector(A.colInd), Vector(A.nzVal), A.dims...)
116121
return A
117122
end

src/indexing.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,17 +20,17 @@ function Base.findall(bools::oneArray{Bool})
2020
I = keytype(bools)
2121

2222
indices = cumsum(reshape(bools, prod(size(bools))))
23-
oneL0.synchronize()
2423

2524
n = isempty(indices) ? 0 : @allowscalar indices[end]
2625

2726
ys = oneArray{I}(undef, n)
2827

2928
if n > 0
30-
@oneapi items = length(bools) _ker!(ys, bools, indices)
29+
kernel = @oneapi launch=false _ker!(ys, bools, indices)
30+
group_size = launch_configuration(kernel)
31+
kernel(ys, bools, indices; items=group_size, groups=cld(length(bools), group_size))
3132
end
32-
oneL0.synchronize()
33-
unsafe_free!(indices)
33+
# unsafe_free!(indices)
3434

3535
return ys
3636
end

test/indexing.jl

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,4 +17,18 @@ using oneAPI
1717
data = oneArray(collect(1:6))
1818
mask = oneArray(Bool[true, false, true, false, false, true])
1919
@test Array(data[mask]) == collect(1:6)[findall(Bool[true, false, true, false, false, true])]
20+
21+
# Test with array larger than 1024 to trigger multiple groups
22+
large_size = 2048
23+
large_mask = oneArray(rand(Bool, large_size))
24+
large_result_gpu = Array(findall(large_mask))
25+
large_result_cpu = findall(Array(large_mask))
26+
@test large_result_gpu == large_result_cpu
27+
28+
# Test with even larger array to ensure robustness
29+
very_large_size = 5000
30+
very_large_mask = oneArray(fill(true, very_large_size)) # all true for predictable result
31+
very_large_result_gpu = Array(findall(very_large_mask))
32+
very_large_result_cpu = findall(fill(true, very_large_size))
33+
@test very_large_result_gpu == very_large_result_cpu
2034
end

0 commit comments

Comments
 (0)