Skip to content

Commit 83b4218

Browse files
authored
Rework find functions (#8)
* Change the bounds checking behaviour of the find* functions to match those of `Vector`. * Add an optimised generic fallback which, unlike the AbstractArray fallback, does not boundscheck in its loop body * Add a fastpath for findprev to dispatch to Libc's memrchr * More thoroughly test find functions
1 parent cbab8ab commit 83b4218

File tree

3 files changed

+175
-27
lines changed

3 files changed

+175
-27
lines changed

Project.toml

+3-3
Original file line numberDiff line numberDiff line change
@@ -6,15 +6,15 @@ authors = ["Jakob Nybo Nissen <[email protected]>"]
66
[weakdeps]
77
StringViews = "354b36f9-a18e-4713-926e-db85100087ba"
88

9+
[extensions]
10+
StringViewsExt = "StringViews"
11+
912
[compat]
1013
Aqua = "0.8.7"
1114
StringViews = "1"
1215
Test = "1.11"
1316
julia = "1.11"
1417

15-
[extensions]
16-
StringViewsExt = "StringViews"
17-
1818
[extras]
1919
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
2020
StringViews = "354b36f9-a18e-4713-926e-db85100087ba"

src/basic.jl

+103-11
Original file line numberDiff line numberDiff line change
@@ -43,11 +43,7 @@ function Base.getindex(v::MemoryView, i::Integer)
4343
@inbounds ref[]
4444
end
4545

46-
function Base.similar(
47-
mem::MemoryView{T1, M},
48-
::Type{T2},
49-
dims::Tuple{Int},
50-
) where {T1, T2, M}
46+
function Base.similar(::MemoryView{T1, M}, ::Type{T2}, dims::Tuple{Int}) where {T1, T2, M}
5147
len = Int(only(dims))::Int
5248
memory = Memory{T2}(undef, len)
5349
MemoryView{T2, M}(unsafe, memoryref(memory), len)
@@ -89,6 +85,23 @@ end
8985
Base.getindex(v::MemoryView, ::Colon) = v
9086
Base.view(v::MemoryView, idx::AbstractUnitRange) = v[idx]
9187

88+
function truncate(mem::MemoryView, include_last::Integer)
89+
lst = Int(include_last)::Int
90+
@boundscheck if (lst % UInt) > length(mem) % UInt
91+
throw(BoundsError(mem, lst))
92+
end
93+
typeof(mem)(unsafe, mem.ref, lst)
94+
end
95+
96+
function truncate_start_nonempty(mem::MemoryView, from::Integer)
97+
frm = Int(from)::Int
98+
@boundscheck if ((frm - 1) % UInt) >= length(mem) % UInt
99+
throw(BoundsError(mem, frm))
100+
end
101+
newref = @inbounds memoryref(mem.ref, frm)
102+
typeof(mem)(unsafe, newref, length(mem) - frm + 1)
103+
end
104+
92105
function Base.unsafe_copyto!(dst::MutableMemoryView{T}, src::MemoryView{T}) where {T}
93106
iszero(length(src)) && return dst
94107
@inbounds unsafe_copyto!(dst.ref, src.ref, length(src))
@@ -105,6 +118,17 @@ function Base.copyto!(dst::MutableMemoryView{T}, src::MemoryView{T}) where {T}
105118
unsafe_copyto!(dst, src)
106119
end
107120

121+
# Optimised methods that don't boundscheck
122+
function Base.findnext(p::Function, mem::MemoryView, start::Integer)
123+
i = Int(start)::Int
124+
@boundscheck (i < 1 && throw(BoundsError(mem, i)))
125+
@inbounds while i <= length(mem)
126+
p(mem[i]) && return i
127+
i += 1
128+
end
129+
nothing
130+
end
131+
108132
# The following two methods could be collapsed, but they aren't for two reasons:
109133
# * To prevent ambiguity with Base
110134
# * Because we DON'T want this code to run with MemoryView{Union{UInt8, Int8}}.
@@ -126,16 +150,25 @@ function Base.findnext(
126150
_findnext(mem, p.x, start)
127151
end
128152

129-
@inline function _findnext(
153+
function Base.findnext(
154+
::typeof(iszero),
155+
mem::Union{MemoryView{Int8}, MemoryView{UInt8}},
156+
i::Integer,
157+
)
158+
_findnext(mem, zero(eltype(mem)), i)
159+
end
160+
161+
Base.@propagate_inbounds function _findnext(
130162
mem::MemoryView{T},
131-
byte::Union{T},
163+
byte::T,
132164
start::Integer,
133165
) where {T <: Union{UInt8, Int8}}
134166
start = Int(start)::Int
135-
real_start = max(start, 1)
136-
v = @inbounds ImmutableMemoryView(mem[real_start:end])
137-
v_ind = @something memchr(v, byte) return nothing
138-
v_ind + real_start - 1
167+
@boundscheck(start < 1 && throw(BoundsError(mem, start)))
168+
start > length(mem) && return nothing
169+
im = @inbounds truncate_start_nonempty(ImmutableMemoryView(mem), start)
170+
v_ind = @something memchr(im, byte) return nothing
171+
v_ind + start - 1
139172
end
140173

141174
function memchr(mem::ImmutableMemoryView{T}, byte::T) where {T <: Union{Int8, UInt8}}
@@ -151,6 +184,65 @@ function memchr(mem::ImmutableMemoryView{T}, byte::T) where {T <: Union{Int8, UI
151184
p == C_NULL ? nothing : (p - ptr) % Int + 1
152185
end
153186

187+
function Base.findprev(p::Function, mem::MemoryView, start::Integer)
188+
i = Int(start)::Int
189+
@boundscheck (i > length(mem) && throw(BoundsError(mem, i)))
190+
@inbounds while i > 0
191+
p(mem[i]) && return i
192+
i -= 1
193+
end
194+
nothing
195+
end
196+
197+
function Base.findprev(
198+
p::Base.Fix2{<:Union{typeof(==), typeof(isequal)}, UInt8},
199+
mem::MemoryView{UInt8},
200+
start::Integer,
201+
)
202+
_findprev(mem, p.x, start)
203+
end
204+
205+
function Base.findprev(
206+
p::Base.Fix2{<:Union{typeof(==), typeof(isequal)}, Int8},
207+
mem::MemoryView{Int8},
208+
start::Integer,
209+
)
210+
_findprev(mem, p.x, start)
211+
end
212+
213+
function Base.findprev(
214+
::typeof(iszero),
215+
mem::Union{MemoryView{Int8}, MemoryView{UInt8}},
216+
i::Integer,
217+
)
218+
_findprev(mem, zero(eltype(mem)), i)
219+
end
220+
221+
Base.@propagate_inbounds function _findprev(
222+
mem::MemoryView{T},
223+
byte::T,
224+
start::Integer,
225+
) where {T <: Union{UInt8, Int8}}
226+
start = Int(start)::Int
227+
@boundscheck (start > length(mem) && throw(BoundsError(mem, start)))
228+
start < 1 && return nothing
229+
im = @inbounds truncate(ImmutableMemoryView(mem), start)
230+
memrchr(im, byte)
231+
end
232+
233+
function memrchr(mem::ImmutableMemoryView{T}, byte::T) where {T <: Union{Int8, UInt8}}
234+
isempty(mem) && return nothing
235+
GC.@preserve mem begin
236+
ptr = Ptr{UInt8}(pointer(mem))
237+
p = @ccall memrchr(
238+
ptr::Ptr{UInt8},
239+
(byte % UInt8)::UInt8,
240+
length(mem)::Int,
241+
)::Ptr{Cvoid}
242+
end
243+
p == C_NULL ? nothing : (p - ptr) % Int + 1
244+
end
245+
154246
const Bits =
155247
Union{Int8, UInt8, Int16, UInt16, Int32, UInt32, Int64, UInt64, Int128, UInt128, Char}
156248

test/runtests.jl

+69-13
Original file line numberDiff line numberDiff line change
@@ -372,19 +372,6 @@ end
372372
@test v2 == v1
373373
end
374374

375-
@testset "Find" begin
376-
mem = MemoryView([4, 3, 2])
377-
@test findfirst(==(2), mem) == 3
378-
379-
mem = MemoryView(Int8[6, 2, 7, 0, 2])
380-
@test findfirst(iszero, mem) == 4
381-
@test findfirst(==(Int8(0)), mem) == 4
382-
383-
mem = MemoryView(UInt8[1, 4, 2, 5, 6])
384-
@test findnext(==(0x04), mem, 1) == 2
385-
@test findnext(==(0x04), mem, 3) === nothing
386-
end
387-
388375
@testset "Reverse and reverse!" begin
389376
for v in [
390377
["a", "abc", "a", "c", "kij"],
@@ -463,6 +450,75 @@ end
463450
@test split_unaligned(v, Val(8)) == split_at(v, 3)
464451
@test split_unaligned(v, Val(16)) == split_at(v, 7)
465452
end
453+
454+
@testset "Find" begin
455+
@testset "Generic find" begin
456+
mem = ImmutableMemoryView([1, 2, 3, 4])
457+
@test findfirst(isodd, mem) == 1
458+
@test findfirst(isodd, mem[2:end]) == 2
459+
@test findfirst(mem[1:0]) === nothing
460+
461+
@test findlast(isodd, mem) == 3
462+
@test findlast(isodd, mem[1:2]) == 1
463+
@test findlast(isodd, mem[1:0]) === nothing
464+
465+
@test findnext(isodd, mem, 0x02) == 3
466+
@test findnext(isodd, mem, 3) == 3
467+
@test findnext(isodd, mem, 0x04) === nothing
468+
@test findnext(isodd, mem, 10) === nothing
469+
470+
@test_throws BoundsError findnext(isodd, mem, 0)
471+
@test_throws BoundsError findnext(isodd, mem, -1)
472+
473+
@test findprev(isodd, mem, 4) == 3
474+
@test findprev(isodd, mem, 0x03) == 3
475+
@test findprev(isodd, mem, 2) == 1
476+
@test findprev(isodd, mem, 0x00) === nothing
477+
@test findprev(isodd, mem, -10) === nothing
478+
479+
@test_throws BoundsError findprev(isodd, mem, 5)
480+
@test_throws BoundsError findprev(isodd, mem, 7)
481+
end
482+
483+
@testset "Memchr routines" begin
484+
for T in Any[Int8, UInt8]
485+
mem = MemoryView(T[6, 2, 7, 0, 2, 1])
486+
@test findfirst(iszero, mem) == 4
487+
@test findfirst(==(T(2)), mem) == 2
488+
@test findnext(==(T(2)), mem, 3) == 5
489+
@test findnext(==(T(7)), mem, 4) === nothing
490+
@test findnext(==(T(2)), mem, 7) === nothing
491+
@test_throws BoundsError findnext(iszero, mem, 0)
492+
@test_throws BoundsError findnext(iszero, mem, -3)
493+
494+
@test findlast(iszero, mem) == 4
495+
@test findprev(iszero, mem, 3) === nothing
496+
@test findprev(iszero, mem, 4) == 4
497+
@test findprev(==(T(2)), mem, 5) == 5
498+
@test findprev(==(T(2)), mem, 4) == 2
499+
@test findprev(==(T(9)), mem, 3) === nothing
500+
@test findprev(==(T(2)), mem, -2) === nothing
501+
@test findprev(iszero, mem, 0) === nothing
502+
@test_throws BoundsError findprev(iszero, mem, 7)
503+
end
504+
mem = MemoryView(Int8[2, 3, -1])
505+
@test findfirst(==(0xff), mem) === nothing
506+
@test findprev(==(0xff), mem, 3) === nothing
507+
end
508+
509+
@testset "Find" begin
510+
mem = MemoryView([4, 3, 2])
511+
@test findfirst(==(2), mem) == 3
512+
513+
mem = MemoryView(Int8[6, 2, 7, 0, 2])
514+
@test findfirst(iszero, mem) == 4
515+
@test findfirst(==(Int8(0)), mem) == 4
516+
517+
mem = MemoryView(UInt8[1, 4, 2, 5, 6])
518+
@test findnext(==(0x04), mem, 1) == 2
519+
@test findnext(==(0x04), mem, 3) === nothing
520+
end
521+
end
466522
end
467523

468524
@testset "Iterators.reverse" begin

0 commit comments

Comments
 (0)