Skip to content

Commit

Permalink
Improve performance of weighted sampling single (#72)
Browse files Browse the repository at this point in the history
  • Loading branch information
Tortar committed Apr 27, 2024
1 parent 6cee15a commit 294e8ad
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 19 deletions.
35 changes: 24 additions & 11 deletions src/WeightedSamplingSingle.jl
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@

"""
    RefVal{T}

Mutable single-slot container for a value of type `T`.

`RefVal{T}()` creates the container without assigning `value`; `RefVal(value)`
creates it already holding `value`, taking `T` from the argument's type.
"""
mutable struct RefVal{T}
    value::T

    # Empty box: `new{T}()` leaves `value` unassigned until it is written.
    function RefVal{T}() where {T}
        return new{T}()
    end

    # Filled box: `T` is taken from the type of `value`.
    function RefVal(value::T) where {T}
        return new{T}(value)
    end
end

"""
    ImmutSampleSingleAlgARes{T,R}

Immutable single-element weighted reservoir sample for the `AlgARes` method.

# Fields
- `state::Float64`: smallest priority accepted so far; `update!` replaces the
  stored element only when a new priority is strictly below it.
- `rng::R`: random number generator used to draw priorities.
- `rvalue::RefVal{T}`: box holding the currently selected element. The element
  lives in a `RefVal` so that an empty sample can be built without a value.
"""
struct ImmutSampleSingleAlgARes{T,R} <: AbstractWeightedReservoirSampleSingle
    state::Float64
    rng::R
    rvalue::RefVal{T}
end

# Two-argument form for building an empty sample: the element box is created
# here, unassigned, so callers only have to supply the state and the generator.
function ImmutSampleSingleAlgARes{T,R}(state, rng) where {T,R}
    return ImmutSampleSingleAlgARes{T,R}(state, rng, RefVal{T}())
end
mutable struct MutSampleSingleAlgARes{T,R} <: AbstractWeightedReservoirSampleSingle
state::Float64
# (diff hunk boundary "@@ -18,9 +22,7 @@": the lines between the previous hunk and
# this struct are collapsed in this view)
"""
    ImmutSampleSingleAlgAExpJ{T,R}

Immutable single-element weighted reservoir sample for the `AlgAExpJ` method.

# Fields
- `state::Float64`: running sum of the weights passed to `update!`.
- `skip_w::Float64`: threshold on `state`; the stored element is replaced once
  `state` reaches it, after which a new threshold is drawn.
- `rng::R`: random number generator used to draw thresholds.
- `rvalue::RefVal{T}`: box holding the currently selected element. The element
  lives in a `RefVal` so that an empty sample can be built without a value.
"""
struct ImmutSampleSingleAlgAExpJ{T,R} <: AbstractWeightedReservoirSampleSingle
    state::Float64
    skip_w::Float64
    rng::R
    rvalue::RefVal{T}
end

# Three-argument form for building an empty sample: the element box is created
# here, unassigned, so callers only have to supply the counters and the generator.
function ImmutSampleSingleAlgAExpJ{T,R}(state, skip_w, rng) where {T,R}
    return ImmutSampleSingleAlgAExpJ{T,R}(state, skip_w, rng, RefVal{T}())
end
mutable struct MutSampleSingleAlgAExpJ{T,R} <: AbstractWeightedReservoirSampleSingle
state::Float64
Expand All @@ -35,37 +37,48 @@ function ReservoirSample(rng::R, T, ::AlgARes, ::MutSample) where {R<:AbstractRN
return MutSampleSingleAlgARes{T,R}(typemax(Float64), rng)
end
# Build an empty immutable single-element sample for `AlgARes` holding elements
# of type `T`.
#
# The state starts at `typemax(Float64)`, so any smaller priority drawn by
# `update!` is accepted. The element box is created unassigned and is filled
# by the first accepted element.
function ReservoirSample(rng::R, T, ::AlgARes, ::ImmutSample) where {R<:AbstractRNG}
    return ImmutSampleSingleAlgARes(typemax(Float64), rng, RefVal{T}())
end
# Build an empty mutable single-element sample for `AlgAExpJ` holding elements
# of type `T`; the first two constructor arguments both start at zero.
function ReservoirSample(rng::R, T, ::AlgAExpJ, ::MutSample) where {R<:AbstractRNG}
    start = 0.0
    return MutSampleSingleAlgAExpJ{T,R}(start, start, rng)
end
# Build an empty immutable single-element sample for `AlgAExpJ` holding
# elements of type `T`.
#
# Both `state` and `skip_w` start at zero. The element box is created
# unassigned and is filled by the first element that `update!` selects.
function ReservoirSample(rng::R, T, ::AlgAExpJ, ::ImmutSample) where {R<:AbstractRNG}
    return ImmutSampleSingleAlgAExpJ(0.0, 0.0, rng, RefVal{T}())
end

# Return the element currently held by the sample `s`, or `nothing` when the
# sample's state is still exactly `0.0`.
#
# The element is read through `get_val` so that both storage layouts work:
# mutable samples keep it in a `value` field, immutable ones in an `rvalue` box.
#
# NOTE(review): `AlgARes` samples are created with `state = typemax(Float64)`,
# so the `0.0` guard does not seem to recognise an `AlgARes` sample that never
# received an element — confirm whether that case needs its own check.
function value(s::AbstractWeightedReservoirSampleSingle)
    s.state === 0.0 && return nothing
    return get_val(s)
end

# Offer element `el` with weight `w` to the `AlgARes` sample `s` and return the
# resulting sample.
#
# A priority `randexp(s.rng)/w` is drawn for the element. It replaces the stored
# element only when that priority is strictly smaller than the best one seen so
# far (`s.state`). Callers must use the returned sample: `s` is rebound inside
# the branch, so the return value is not necessarily the object passed in.
@inline function update!(s::SampleSingleAlgARes, el, w)
    priority = randexp(s.rng)/w
    if priority < s.state
        @imm_reset s.state = priority
        # Store through `set_val` so the write goes to whichever field this
        # sample type keeps its element in.
        s = set_val(s, el)
    end
    return s
end
# Offer element `el` with weight `weight` to the `AlgAExpJ` sample `s` and
# return the resulting sample.
#
# `s.state` accumulates the weights seen so far. Once it reaches the current
# threshold `s.skip_w`, a new threshold `s.state/rand(s.rng)` is drawn and `el`
# becomes the stored element. Callers must use the returned sample: `s` is
# rebound inside the branch, so the return value is not necessarily the object
# passed in.
@inline function update!(s::SampleSingleAlgAExpJ, el, weight)
    @imm_reset s.state += weight
    if s.skip_w <= s.state
        @imm_reset s.skip_w = s.state/rand(s.rng)
        # Store through `set_val` so the write goes to whichever field this
        # sample type keeps its element in.
        s = set_val(s, el)
    end
    return s
end

# Read the stored element of an immutable sample out of its `rvalue` box.
function get_val(s::Union{ImmutSampleSingleAlgARes,ImmutSampleSingleAlgAExpJ})
    box = s.rvalue
    return box.value
end
# Return the immutable sample `s` with `el` as its stored element.
#
# `@reset` rebinds the local `s` to the updated sample, so the caller has to
# keep the returned value.
function set_val(
    s::Union{ImmutSampleSingleAlgARes,ImmutSampleSingleAlgAExpJ},
    el,
)
    @reset s.rvalue.value = el
    return s
end
# Read the stored element of a mutable sample straight from its `value` field.
function get_val(s::Union{MutSampleSingleAlgARes,MutSampleSingleAlgAExpJ})
    return s.value
end
# Overwrite the `value` field of the mutable sample `s` with `el` and hand the
# same sample back, mirroring the return convention of the immutable method.
function set_val(
    s::Union{MutSampleSingleAlgARes,MutSampleSingleAlgAExpJ},
    el,
)
    setproperty!(s, :value, el)
    return s
end

function itsample(iter, wv::Function, method::ReservoirAlgorithm = algAExpJ;
iter_type = infer_eltype(iter))
return itsample(Random.default_rng(), iter, wv, method)
Expand Down
16 changes: 8 additions & 8 deletions test/benchmark_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,22 +3,22 @@
# Benchmark input: the integers 1 through 10^4 with 10 filtered out, built lazily.
iter = Iterators.filter(!=(10), 1:10^4)
# Weight function for the weighted benchmarks: every element gets weight 1.0.
function wv(el)
    return 1.0
end
# Benchmark each unweighted method, once as a single-element draw and once as a
# draw of 10 elements, printing the median time and the memory per configuration.
for m in (algR, algL, algRSWRSKIP)
    # `nothing` stands for the single-element call, which takes no size argument.
    for size in (nothing, 10)
        # The single-element configuration is not benchmarked for `algRSWRSKIP`.
        size === nothing && m === algRSWRSKIP && continue
        # Splatted below: empty for the single-element call, `(size,)` otherwise.
        s = size === nothing ? () : (size,)
        b = @benchmark itsample($rng, $iter, $s..., $m) evals=1
        mstr = "$m $(size === nothing ? :single : :multi)"
        # `rpad` aligns the label column and, unlike a manual `repeat`, does not
        # throw when a label is longer than the column width.
        print(rpad(mstr, 35))
        print(" --> Time: $(@sprintf("%.2f", median(b.times)*1e-3)) μs |")
        println(" Memory: $(b.memory) bytes")
    end
end
for m in (algARes, algAExpJ, algWRSWRSKIP)
for size in (1, 10)
size == 1 && m === algWRSWRSKIP && continue
s = size == 1 ? () : (10,)
for size in (nothing, 10)
size == nothing && m === algWRSWRSKIP && continue
s = size == nothing ? () : (size,)
b = @benchmark itsample($rng, $iter, $wv, $s..., $m) evals=1
mstr = "$m $(size == 1 ? :single : :multi)"
mstr = "$m $(size == nothing ? :single : :multi)"
print(mstr * repeat(" ", 35-length(mstr)))
print(" --> Time: $(@sprintf("%.2f", median(b.times)*1e-3)) μs |")
println(" Memory: $(b.memory) bytes")
Expand Down

0 comments on commit 294e8ad

Please sign in to comment.