Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add optimised findfirst etc #8

Merged
merged 6 commits into from
Dec 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,15 @@ authors = ["Jakob Nybo Nissen <[email protected]>"]
[weakdeps]
StringViews = "354b36f9-a18e-4713-926e-db85100087ba"

[extensions]
StringViewsExt = "StringViews"

[compat]
Aqua = "0.8.7"
StringViews = "1"
Test = "1.11"
julia = "1.11"

[extensions]
StringViewsExt = "StringViews"

[extras]
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
StringViews = "354b36f9-a18e-4713-926e-db85100087ba"
Expand Down
114 changes: 103 additions & 11 deletions src/basic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -43,11 +43,7 @@ function Base.getindex(v::MemoryView, i::Integer)
@inbounds ref[]
end

function Base.similar(
mem::MemoryView{T1, M},
::Type{T2},
dims::Tuple{Int},
) where {T1, T2, M}
function Base.similar(::MemoryView{T1, M}, ::Type{T2}, dims::Tuple{Int}) where {T1, T2, M}
len = Int(only(dims))::Int
memory = Memory{T2}(undef, len)
MemoryView{T2, M}(unsafe, memoryref(memory), len)
Expand Down Expand Up @@ -89,6 +85,23 @@ end
# Indexing with a bare `:` hands back `v` itself.
function Base.getindex(v::MemoryView, ::Colon)
    return v
end
# `view` over a unit range is forwarded to range indexing, `v[idx]`.
function Base.view(v::MemoryView, idx::AbstractUnitRange)
    return v[idx]
end

# Return a view over the first `include_last` elements of `mem`, built from the
# same starting reference as `mem`.
#
# `include_last` is converted to `Int`. Unless bounds checking is elided, a
# `BoundsError` is thrown when the value is negative or exceeds `length(mem)`.
# Zero is accepted and produces a view of length zero.
function truncate(mem::MemoryView, include_last::Integer)
    newlen = Int(include_last)::Int
    @boundscheck begin
        # Reinterpreting both sides as unsigned lets one comparison reject both
        # negative values (which wrap to huge numbers) and too-large values.
        (newlen % UInt) <= (length(mem) % UInt) || throw(BoundsError(mem, newlen))
    end
    return typeof(mem)(unsafe, mem.ref, newlen)
end

# Return the tail of `mem` that begins at index `from` (inclusive).
#
# `from` is converted to `Int` and must satisfy `1 <= from <= length(mem)`, so
# the result always holds at least one element. Unless bounds checking is
# elided, any other value throws a `BoundsError`.
function truncate_start_nonempty(mem::MemoryView, from::Integer)
    first_idx = Int(from)::Int
    @boundscheck begin
        # `(first_idx - 1)` viewed as unsigned is out of range both when
        # `first_idx < 1` (wraps around) and when `first_idx > length(mem)`.
        ((first_idx - 1) % UInt) < (length(mem) % UInt) ||
            throw(BoundsError(mem, first_idx))
    end
    tail_ref = @inbounds memoryref(mem.ref, first_idx)
    tail_len = length(mem) - first_idx + 1
    return typeof(mem)(unsafe, tail_ref, tail_len)
end

function Base.unsafe_copyto!(dst::MutableMemoryView{T}, src::MemoryView{T}) where {T}
iszero(length(src)) && return dst
@inbounds unsafe_copyto!(dst.ref, src.ref, length(src))
Expand All @@ -105,6 +118,17 @@ function Base.copyto!(dst::MutableMemoryView{T}, src::MemoryView{T}) where {T}
unsafe_copyto!(dst, src)
end

# Optimised methods that don't boundscheck
# Return the first index `i >= start` for which `p(mem[i])` is true, or
# `nothing` when there is none (including when `start > length(mem)`).
#
# Unless bounds checking is elided, `start < 1` throws a `BoundsError`.
# Element accesses inside the scan are done with `@inbounds`.
function Base.findnext(p::Function, mem::MemoryView, start::Integer)
    idx = Int(start)::Int
    @boundscheck begin
        idx < 1 && throw(BoundsError(mem, idx))
    end
    @inbounds while true
        idx > length(mem) && break
        p(mem[idx]) && return idx
        idx += 1
    end
    return nothing
end

# The following two methods could be collapsed, but they aren't for two reasons:
# * To prevent ambiguity with Base
# * Because we DON'T want this code to run with MemoryView{Union{UInt8, Int8}}.
Expand All @@ -126,16 +150,25 @@ function Base.findnext(
_findnext(mem, p.x, start)
end

@inline function _findnext(
# `findnext(iszero, mem, i)` for byte-sized element types: search for the zero
# value of the element type via `_findnext`.
function Base.findnext(
    ::typeof(iszero),
    mem::Union{MemoryView{Int8}, MemoryView{UInt8}},
    i::Integer,
)
    needle = zero(eltype(mem))
    return _findnext(mem, needle, i)
end

Base.@propagate_inbounds function _findnext(
mem::MemoryView{T},
byte::Union{T},
byte::T,
start::Integer,
) where {T <: Union{UInt8, Int8}}
start = Int(start)::Int
real_start = max(start, 1)
v = @inbounds ImmutableMemoryView(mem[real_start:end])
v_ind = @something memchr(v, byte) return nothing
v_ind + real_start - 1
@boundscheck(start < 1 && throw(BoundsError(mem, start)))
start > length(mem) && return nothing
im = @inbounds truncate_start_nonempty(ImmutableMemoryView(mem), start)
v_ind = @something memchr(im, byte) return nothing
v_ind + start - 1
end

function memchr(mem::ImmutableMemoryView{T}, byte::T) where {T <: Union{Int8, UInt8}}
Expand All @@ -151,6 +184,65 @@ function memchr(mem::ImmutableMemoryView{T}, byte::T) where {T <: Union{Int8, UI
p == C_NULL ? nothing : (p - ptr) % Int + 1
end

# Return the last index `i <= start` for which `p(mem[i])` is true, or
# `nothing` when there is none (including when `start < 1`).
#
# Unless bounds checking is elided, `start > length(mem)` throws a
# `BoundsError`. Element accesses inside the scan are done with `@inbounds`.
function Base.findprev(p::Function, mem::MemoryView, start::Integer)
    idx = Int(start)::Int
    @boundscheck begin
        idx > length(mem) && throw(BoundsError(mem, idx))
    end
    @inbounds while true
        idx < 1 && break
        p(mem[idx]) && return idx
        idx -= 1
    end
    return nothing
end

# `findprev(==(x), mem, start)` / `findprev(isequal(x), mem, start)` for a
# `UInt8` needle in a `MemoryView{UInt8}`: hand the fixed value to `_findprev`.
function Base.findprev(
    p::Base.Fix2{<:Union{typeof(==), typeof(isequal)}, UInt8},
    mem::MemoryView{UInt8},
    start::Integer,
)
    needle = p.x
    return _findprev(mem, needle, start)
end

# `findprev(==(x), mem, start)` / `findprev(isequal(x), mem, start)` for an
# `Int8` needle in a `MemoryView{Int8}`: hand the fixed value to `_findprev`.
function Base.findprev(
    p::Base.Fix2{<:Union{typeof(==), typeof(isequal)}, Int8},
    mem::MemoryView{Int8},
    start::Integer,
)
    needle = p.x
    return _findprev(mem, needle, start)
end

# `findprev(iszero, mem, i)` for byte-sized element types: search backwards for
# the zero value of the element type via `_findprev`.
function Base.findprev(
    ::typeof(iszero),
    mem::Union{MemoryView{Int8}, MemoryView{UInt8}},
    i::Integer,
)
    needle = zero(eltype(mem))
    return _findprev(mem, needle, i)
end

# Backwards byte search: return the last index `i <= start` with
# `mem[i] == byte`, or `nothing` if there is none.
#
# Unless bounds checking is elided, `start > length(mem)` throws a
# `BoundsError`. A `start` below 1 yields `nothing`. The actual scan is
# delegated to `memrchr` on the first `start` elements of `mem`.
Base.@propagate_inbounds function _findprev(
    mem::MemoryView{T},
    byte::T,
    start::Integer,
) where {T <: Union{UInt8, Int8}}
    last_idx = Int(start)::Int
    @boundscheck begin
        last_idx > length(mem) && throw(BoundsError(mem, last_idx))
    end
    if last_idx < 1
        return nothing
    end
    # `last_idx` is at least 1 here, and has been checked against
    # `length(mem)` above unless the caller elided bounds checks.
    haystack = @inbounds truncate(ImmutableMemoryView(mem), last_idx)
    return memrchr(haystack, byte)
end

# Return the 1-based index of the last element of `mem` equal to `byte`, or
# `nothing` if `byte` does not occur (an empty `mem` short-circuits to `nothing`).
#
# The search is performed by the C function `memrchr`, whose prototype is
# `void *memrchr(const void *s, int c, size_t n)`. The `@ccall` argument types
# are declared to mirror that prototype (`Cint` for the byte, `Csize_t` for the
# length) rather than the narrower Julia types of the values being passed.
function memrchr(mem::ImmutableMemoryView{T}, byte::T) where {T <: Union{Int8, UInt8}}
    isempty(mem) && return nothing
    # Keep `mem` rooted while its raw pointer is in use by the C call.
    GC.@preserve mem begin
        ptr = Ptr{UInt8}(pointer(mem))
        p = @ccall memrchr(
            ptr::Ptr{UInt8},
            # Reinterpret `Int8` as its `UInt8` bit pattern before widening,
            # so the value passed is always in 0:255.
            (byte % UInt8)::Cint,
            length(mem)::Csize_t,
        )::Ptr{Cvoid}
    end
    # A null result means "not found"; otherwise convert the pointer offset
    # from the start of the buffer into a 1-based index.
    return p == C_NULL ? nothing : (p - ptr) % Int + 1
end

# Union of the signed and unsigned integer types from 8 to 128 bits wide, plus `Char`.
# NOTE(review): how this alias is used is not visible in this part of the file.
const Bits =
Union{Int8, UInt8, Int16, UInt16, Int32, UInt32, Int64, UInt64, Int128, UInt128, Char}

Expand Down
82 changes: 69 additions & 13 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -372,19 +372,6 @@ end
@test v2 == v1
end

@testset "Find" begin
    # Int elements
    ints = MemoryView([4, 3, 2])
    @test findfirst(==(2), ints) == 3

    # Int8 elements: search for zero via `iszero` and via `==`
    signed_bytes = MemoryView(Int8[6, 2, 7, 0, 2])
    @test findfirst(iszero, signed_bytes) == 4
    @test findfirst(==(Int8(0)), signed_bytes) == 4

    # UInt8 elements: a hit from the start, and a miss when starting past it
    unsigned_bytes = MemoryView(UInt8[1, 4, 2, 5, 6])
    @test findnext(==(0x04), unsigned_bytes, 1) == 2
    @test findnext(==(0x04), unsigned_bytes, 3) === nothing
end

@testset "Reverse and reverse!" begin
for v in [
["a", "abc", "a", "c", "kij"],
Expand Down Expand Up @@ -463,6 +450,75 @@ end
@test split_unaligned(v, Val(8)) == split_at(v, 3)
@test split_unaligned(v, Val(16)) == split_at(v, 7)
end

@testset "Find" begin
    @testset "Generic find" begin
        mem = ImmutableMemoryView([1, 2, 3, 4])
        @test findfirst(isodd, mem) == 1
        @test findfirst(isodd, mem[2:end]) == 2
        # Empty view: pass the predicate so the predicate-based method is the
        # one being exercised, mirroring the `findlast` check below.
        @test findfirst(isodd, mem[1:0]) === nothing

        @test findlast(isodd, mem) == 3
        @test findlast(isodd, mem[1:2]) == 1
        @test findlast(isodd, mem[1:0]) === nothing

        # Start indices are given both as `Int` and as other integer types
        @test findnext(isodd, mem, 0x02) == 3
        @test findnext(isodd, mem, 3) == 3
        @test findnext(isodd, mem, 0x04) === nothing
        @test findnext(isodd, mem, 10) === nothing

        # Starting before the first index is an error for `findnext`
        @test_throws BoundsError findnext(isodd, mem, 0)
        @test_throws BoundsError findnext(isodd, mem, -1)

        @test findprev(isodd, mem, 4) == 3
        @test findprev(isodd, mem, 0x03) == 3
        @test findprev(isodd, mem, 2) == 1
        @test findprev(isodd, mem, 0x00) === nothing
        @test findprev(isodd, mem, -10) === nothing

        # Starting after the last index is an error for `findprev`
        @test_throws BoundsError findprev(isodd, mem, 5)
        @test_throws BoundsError findprev(isodd, mem, 7)
    end

    @testset "Memchr routines" begin
        for T in Any[Int8, UInt8]
            mem = MemoryView(T[6, 2, 7, 0, 2, 1])
            @test findfirst(iszero, mem) == 4
            @test findfirst(==(T(2)), mem) == 2
            @test findnext(==(T(2)), mem, 3) == 5
            @test findnext(==(T(7)), mem, 4) === nothing
            @test findnext(==(T(2)), mem, 7) === nothing
            @test_throws BoundsError findnext(iszero, mem, 0)
            @test_throws BoundsError findnext(iszero, mem, -3)

            @test findlast(iszero, mem) == 4
            @test findprev(iszero, mem, 3) === nothing
            @test findprev(iszero, mem, 4) == 4
            @test findprev(==(T(2)), mem, 5) == 5
            @test findprev(==(T(2)), mem, 4) == 2
            @test findprev(==(T(9)), mem, 3) === nothing
            @test findprev(==(T(2)), mem, -2) === nothing
            @test findprev(iszero, mem, 0) === nothing
            @test_throws BoundsError findprev(iszero, mem, 7)
        end
        # A `UInt8` needle must not match an `Int8` element with the same bit
        # pattern (`0xff` vs `Int8(-1)`).
        mem = MemoryView(Int8[2, 3, -1])
        @test findfirst(==(0xff), mem) === nothing
        @test findprev(==(0xff), mem, 3) === nothing
    end

    @testset "Find" begin
        mem = MemoryView([4, 3, 2])
        @test findfirst(==(2), mem) == 3

        mem = MemoryView(Int8[6, 2, 7, 0, 2])
        @test findfirst(iszero, mem) == 4
        @test findfirst(==(Int8(0)), mem) == 4

        mem = MemoryView(UInt8[1, 4, 2, 5, 6])
        @test findnext(==(0x04), mem, 1) == 2
        @test findnext(==(0x04), mem, 3) === nothing
    end
end
end

@testset "Iterators.reverse" begin
Expand Down
Loading