From 294e8ad240a356b05bac34c6b19811e0761d3489 Mon Sep 17 00:00:00 2001
From: Tortar <68152031+Tortar@users.noreply.github.com>
Date: Sat, 27 Apr 2024 02:18:04 +0200
Subject: [PATCH] Improve performance of weighted sampling single (#72)

---
Review notes (commentary only; placed after the three-dash separator so it
stays out of the commit message):

* The immutable single-element weighted samplers (AlgARes, AlgAExpJ) now keep
  the sampled element in a mutable `RefVal{T}` box stored in the `rvalue`
  field instead of a plain `value::T` field. `get_val`/`set_val` hide that
  difference from `value` and `update!`, which keep one body for both the
  mutable and the immutable sampler types.
* `RefVal{T}()` calls `new{T}()` without arguments, so the boxed `value` is
  left unassigned until the first `set_val`.
* NOTE(review): `value` returns `nothing` only when `state === 0.0`, while
  both AlgARes constructors start `state` at `typemax(Float64)`. Confirm what
  `value` is expected to do for an AlgARes sampler that never saw an element.
* Benchmarks: the single-element case is marked with `nothing` (tested with
  `=== nothing`), and the multi-element case passes `(size,)` instead of a
  hard-coded `(10,)`.

 src/WeightedSamplingSingle.jl | 35 ++++++++++++++++++++++++-----------
 test/benchmark_tests.jl       | 16 ++++++++--------
 2 files changed, 32 insertions(+), 19 deletions(-)

diff --git a/src/WeightedSamplingSingle.jl b/src/WeightedSamplingSingle.jl
index 9e15fff..f6627db 100644
--- a/src/WeightedSamplingSingle.jl
+++ b/src/WeightedSamplingSingle.jl
@@ -1,10 +1,14 @@
 
+mutable struct RefVal{T}
+    value::T
+    RefVal{T}() where T = new{T}()
+    RefVal(value::T) where T = new{T}(value)
+end
+
 struct ImmutSampleSingleAlgARes{T,R} <: AbstractWeightedReservoirSampleSingle
     state::Float64
     rng::R
-    value::T
-    ImmutSampleSingleAlgARes(state, rng::R, value::T) where {T,R} = new{T,R}(state, rng, value)
-    ImmutSampleSingleAlgARes{T,R}(state, rng) where {T,R} = new{T,R}(state, rng)
+    rvalue::RefVal{T}
 end
 mutable struct MutSampleSingleAlgARes{T,R} <: AbstractWeightedReservoirSampleSingle
     state::Float64
@@ -18,9 +22,7 @@ struct ImmutSampleSingleAlgAExpJ{T,R} <: AbstractWeightedReservoirSampleSingle
     state::Float64
     skip_w::Float64
     rng::R
-    value::T
-    ImmutSampleSingleAlgAExpJ(state, skip_w, rng::R, value::T) where {T,R} = new{T,R}(state, skip_w, rng, value)
-    ImmutSampleSingleAlgAExpJ{T,R}(state, skip_w, rng) where {T,R} = new{T,R}(state, skip_w, rng)
+    rvalue::RefVal{T}
 end
 mutable struct MutSampleSingleAlgAExpJ{T,R} <: AbstractWeightedReservoirSampleSingle
     state::Float64
@@ -35,37 +37,48 @@ function ReservoirSample(rng::R, T, ::AlgARes, ::MutSample) where {R<:AbstractRN
     return MutSampleSingleAlgARes{T,R}(typemax(Float64), rng)
 end
 function ReservoirSample(rng::R, T, ::AlgARes, ::ImmutSample) where {R<:AbstractRNG}
-    return ImmutSampleSingleAlgARes{T,R}(typemax(Float64), rng)
+    return ImmutSampleSingleAlgARes(typemax(Float64), rng, RefVal{T}())
 end
 function ReservoirSample(rng::R, T, ::AlgAExpJ, ::MutSample) where {R<:AbstractRNG}
     return MutSampleSingleAlgAExpJ{T,R}(0.0, 0.0, rng)
 end
 function ReservoirSample(rng::R, T, ::AlgAExpJ, ::ImmutSample) where {R<:AbstractRNG}
-    return ImmutSampleSingleAlgAExpJ{T,R}(0.0, 0.0, rng)
+    return ImmutSampleSingleAlgAExpJ(0.0, 0.0, rng, RefVal{T}())
 end
 
 function value(s::AbstractWeightedReservoirSampleSingle)
     s.state === 0.0 && return nothing
-    return s.value
+    return get_val(s)
 end
 
 @inline function update!(s::SampleSingleAlgARes, el, w)
     priority = randexp(s.rng)/w
     if priority < s.state
         @imm_reset s.state = priority
-        @imm_reset s.value = el
+        s = set_val(s, el)
     end
     return s
 end
 @inline function update!(s::SampleSingleAlgAExpJ, el, weight)
     @imm_reset s.state += weight
     if s.skip_w <= s.state
-        @imm_reset s.value = el
         @imm_reset s.skip_w = s.state/rand(s.rng)
+        s = set_val(s, el)
     end
     return s
 end
 
+get_val(s::Union{ImmutSampleSingleAlgARes, ImmutSampleSingleAlgAExpJ}) = s.rvalue.value
+function set_val(s::Union{ImmutSampleSingleAlgARes, ImmutSampleSingleAlgAExpJ}, el)
+    @reset s.rvalue.value = el
+    return s
+end
+get_val(s::Union{MutSampleSingleAlgARes, MutSampleSingleAlgAExpJ}) = s.value
+function set_val(s::Union{MutSampleSingleAlgARes, MutSampleSingleAlgAExpJ}, el)
+    s.value = el
+    return s
+end
+
 function itsample(iter, wv::Function, method::ReservoirAlgorithm = algAExpJ;
         iter_type = infer_eltype(iter))
     return itsample(Random.default_rng(), iter, wv, method)
diff --git a/test/benchmark_tests.jl b/test/benchmark_tests.jl
index 152f9fc..e852310 100644
--- a/test/benchmark_tests.jl
+++ b/test/benchmark_tests.jl
@@ -3,22 +3,22 @@ iter = Iterators.filter(x -> x != 10, 1:10^4)
 wv(el) = 1.0
 
 for m in (algR, algL, algRSWRSKIP)
-    for size in (1, 10)
-        size == 1 && m === algRSWRSKIP && continue
-        s = size == 1 ? () : (10,)
+    for size in (nothing, 10)
+        size === nothing && m === algRSWRSKIP && continue
+        s = size === nothing ? () : (size,)
         b = @benchmark itsample($rng, $iter, $s..., $m) evals=1
-        mstr = "$m $(size == 1 ? :single : :multi)"
+        mstr = "$m $(size === nothing ? :single : :multi)"
         print(mstr * repeat(" ", 35-length(mstr)))
         print(" --> Time: $(@sprintf("%.2f", median(b.times)*1e-3)) μs |")
         println(" Memory: $(b.memory) bytes")
     end
 end
 for m in (algARes, algAExpJ, algWRSWRSKIP)
-    for size in (1, 10)
-        size == 1 && m === algWRSWRSKIP && continue
-        s = size == 1 ? () : (10,)
+    for size in (nothing, 10)
+        size === nothing && m === algWRSWRSKIP && continue
+        s = size === nothing ? () : (size,)
         b = @benchmark itsample($rng, $iter, $wv, $s..., $m) evals=1
-        mstr = "$m $(size == 1 ? :single : :multi)"
+        mstr = "$m $(size === nothing ? :single : :multi)"
         print(mstr * repeat(" ", 35-length(mstr)))
         print(" --> Time: $(@sprintf("%.2f", median(b.times)*1e-3)) μs |")
         println(" Memory: $(b.memory) bytes")