From 039eeaaf3184db48260633839337b7f7d81a3a90 Mon Sep 17 00:00:00 2001 From: Tortar <68152031+Tortar@users.noreply.github.com> Date: Fri, 19 Apr 2024 18:51:05 +0200 Subject: [PATCH] Implement merge for supported methods (#65) --- docs/src/index.md | 2 - src/StreamSampling.jl | 11 +++-- src/UnweightedSamplingMulti.jl | 62 ++++++++++++++++++++++++++ src/UnweightedSamplingSingle.jl | 43 ++++++++++++++---- src/WeightedSamplingMulti.jl | 8 ++++ src/WeightedSamplingSingle.jl | 24 ++++++++-- test/weighted_sampling_single_tests.jl | 2 +- 7 files changed, 133 insertions(+), 19 deletions(-) diff --git a/docs/src/index.md b/docs/src/index.md index c542293..952405d 100644 --- a/docs/src/index.md +++ b/docs/src/index.md @@ -91,5 +91,3 @@ julia> @btime itsample($rng, $iter, $wv, 10^4, algAExpJ); julia> @btime sample($rng, collect($iter), Weights($wv.($iter)), 10^4; replace=false); 306.039 ms (43 allocations: 370.19 MiB) ``` - - diff --git a/src/StreamSampling.jl b/src/StreamSampling.jl index 68a0e89..a025b1a 100644 --- a/src/StreamSampling.jl +++ b/src/StreamSampling.jl @@ -18,6 +18,7 @@ const ordworsample = OrdWORSample() abstract type AbstractReservoirSample end # unweighted cases +abstract type AbstractReservoirSampleSingle <: AbstractReservoirSample end abstract type AbstractReservoirSampleMulti <: AbstractReservoirSample end abstract type AbstractWorReservoirSampleMulti <: AbstractReservoirSampleMulti end abstract type AbstractOrdWorReservoirSampleMulti <: AbstractWorReservoirSampleMulti end @@ -25,10 +26,12 @@ abstract type AbstractWrReservoirSampleMulti <: AbstractReservoirSampleMulti end abstract type AbstractOrdWrReservoirSampleMulti <: AbstractWrReservoirSampleMulti end # weighted cases -abstract type AbstractWeightedReservoirSampleMulti <: AbstractReservoirSample end -abstract type AbstractWeightedWorReservoirSampleMulti <: AbstractReservoirSample end -abstract type AbstractWeightedWrReservoirSampleMulti <: AbstractReservoirSample end -abstract type AbstractWeightedOrdWrReservoirSampleMulti <: AbstractReservoirSample end +abstract type AbstractWeightedReservoirSample <: AbstractReservoirSample end +abstract type AbstractWeightedReservoirSampleSingle <: AbstractWeightedReservoirSample end +abstract type AbstractWeightedReservoirSampleMulti <: AbstractWeightedReservoirSample end +abstract type AbstractWeightedWorReservoirSampleMulti <: AbstractWeightedReservoirSample end +abstract type AbstractWeightedWrReservoirSampleMulti <: AbstractWeightedReservoirSample end +abstract type AbstractWeightedOrdWrReservoirSampleMulti <: AbstractWeightedReservoirSample end abstract type ReservoirAlgorithm end diff --git a/src/UnweightedSamplingMulti.jl b/src/UnweightedSamplingMulti.jl index d42a13b..766b1fd 100644 --- a/src/UnweightedSamplingMulti.jl +++ b/src/UnweightedSamplingMulti.jl @@ -178,6 +178,68 @@ end is_ordered(s::AbstractOrdWrReservoirSampleMulti) = true is_ordered(s::AbstractWrReservoirSampleMulti) = false +function Base.merge(s1::AbstractWorReservoirSampleMulti, s2::AbstractWorReservoirSampleMulti) + len1, len2, n1, n2 = check_merging_support(s1, s2) + n_tot = n1 + n2 + p = n2 / n_tot + value = create_new_res_vec(s1, s2, p, len1) + return SampleSingleAlgR(n_tot, s1.rng, value) +end +function Base.merge(s1::AbstractWrReservoirSampleMulti, s2::AbstractWrReservoirSampleMulti) + len1, len2, n1, n2 = check_merging_support(s1, s2) + n_tot = n1 + n2 + p = n2 / n_tot + value = create_new_res_vec(s1, s2, p, len1) + s_merged = SampleMultiAlgRSWRSKIP(n_tot, s1.rng, value) + recompute_skip!(s_merged, len1) + return s_merged +end + +function Base.merge!(s1::SampleMultiAlgR, s2::AbstractWorReservoirSampleMulti) + len1, len2, n1, n2 = check_merging_support(s1, s2) + n_tot = n1 + n2 + p = n2 / n_tot + s1 = merge_res_vec!(s1, s2, p, len1, n_tot) + return s1 +end +function Base.merge!(s1::SampleMultiAlgRSWRSKIP, s2::AbstractWrReservoirSampleMulti) + len1, len2, n1, n2 = check_merging_support(s1, s2) + n_tot = n1 + n2 + p = n2 / n_tot + s1 = merge_res_vec!(s1, s2, p, len1, n_tot) + recompute_skip!(s1, len1) + return s1 +end +function Base.merge!(s1::SampleSingleAlgL, s2::AbstractWorReservoirSampleMulti) + error("Merging into a ReservoirSample using method algL is not supported") +end + +function check_merging_support(s1, s2) + len1, len2 = length(s1.value), length(s2.value) + len1 != len2 && error("Merging samples with different sizes is not supported") + n1, n2 = n_seen(s1), n_seen(s2) + n1 < len1 || n2 < len2 && error("Merging samples with different sizes is not supported") + return len1, len2, n1, n2 +end + +function create_new_res_vec(s1, s2, p, len1) + value = similar(s1.value) + @inbounds for j in 1:len1 + value[j] = rand(s1.rng) < p ? s2.value[j] : s1.value[j] + end + return value +end + +function merge_res_vec!(s1, s2, p, len1, n_tot) + @inbounds for j in 1:len1 + if rand(s1.rng) < p + s1.value[j] = s2.value[j] + end + end + s1.seen_k = n_tot + return s1 +end + function value(s::AbstractWorReservoirSampleMulti) if n_seen(s) < length(s.value) return s.value[1:n_seen(s)] diff --git a/src/UnweightedSamplingSingle.jl b/src/UnweightedSamplingSingle.jl index 6d550a0..edb6f0d 100644 --- a/src/UnweightedSamplingSingle.jl +++ b/src/UnweightedSamplingSingle.jl @@ -1,17 +1,19 @@ -mutable struct SampleSingleAlgL{T,R} <: AbstractReservoirSample +mutable struct SampleSingleAlgL{T,R} <: AbstractReservoirSampleSingle state::Float64 + seen_k::Int skip_k::Int const rng::R value::T - SampleSingleAlgL{T,R}(state, skip_k, rng) where {T,R} = new{T,R}(state, skip_k, rng) + SampleSingleAlgL{T,R}(state, seen_k, skip_k, rng) where {T,R} = new{T,R}(state, seen_k, skip_k, rng) end -mutable struct SampleSingleAlgR{T,R} <: AbstractReservoirSample - state::Int +mutable struct SampleSingleAlgR{T,R} <: AbstractReservoirSampleSingle + seen_k::Int const rng::R value::T - SampleSingleAlgR{T,R}(state, rng) where {T,R} = new{T,R}(state, rng) + SampleSingleAlgR{T,R}(seen_k, rng, value) where {T,R} = new{T,R}(seen_k, rng, value) + SampleSingleAlgR{T,R}(seen_k, rng) where {T,R} = new{T,R}(seen_k, rng) end function value(s::SampleSingleAlgL) @@ -19,7 +21,7 @@ function value(s::SampleSingleAlgL) return s.value end function value(s::SampleSingleAlgR) - s.state === 0 && return nothing + s.seen_k === 0 && return nothing return s.value end @@ -27,20 +29,21 @@ function ReservoirSample(T, method::ReservoirAlgorithm = algL) return ReservoirSample(Random.default_rng(), T, method) end function ReservoirSample(rng::AbstractRNG, T, method::AlgL = algL) - return SampleSingleAlgL{T, typeof(rng)}(1.0, 0, rng) + return SampleSingleAlgL{T, typeof(rng)}(1.0, 0, 0, rng) end function ReservoirSample(rng::AbstractRNG, T, method::AlgR) return SampleSingleAlgR{T, typeof(rng)}(0, rng) end @inline function update!(s::SampleSingleAlgR, el) - s.state += 1 - if rand(s.rng) <= 1/s.state + s.seen_k += 1 + if rand(s.rng) <= 1/s.seen_k s.value = el end return s end @inline function update!(s::SampleSingleAlgL, el) + s.seen_k += 1 if s.skip_k > 0 s.skip_k -= 1 else @@ -51,6 +54,28 @@ end return s end +function Base.merge(s1::AbstractReservoirSampleSingle, s2::AbstractReservoirSampleSingle) + n1, n2 = n_seen(s1), n_seen(s2) + n_tot = n1 + n2 + value = rand(s1.rng) < n1/n_tot ? s1.value : s2.value + return SampleSingleAlgR{typeof(value), typeof(s1.rng)}(n_tot, s1.rng, value) +end + +function Base.merge!(s1::SampleSingleAlgR, s2::AbstractReservoirSampleSingle) + n1, n2 = n_seen(s1), n_seen(s2) + n_tot = n1 + n2 + r = rand(s1.rng) + p = n2 / n_tot + if r < p + s1.value = s2.value + end + s1.seen_k = n_tot + return s1 +end +function Base.merge!(s1::SampleSingleAlgL, s2::AbstractReservoirSampleSingle) + error("Merging into a ReservoirSample using method algL is not supported") +end + function itsample(iter, method::ReservoirAlgorithm = algL) return itsample(Random.default_rng(), iter, method) end diff --git a/src/WeightedSamplingMulti.jl b/src/WeightedSamplingMulti.jl index 1d22947..b2174df 100644 --- a/src/WeightedSamplingMulti.jl +++ b/src/WeightedSamplingMulti.jl @@ -160,6 +160,14 @@ end is_ordered(s::SampleMultiOrdAlgWRSWRSKIP) = true is_ordered(s::SampleMultiAlgWRSWRSKIP) = false +function Base.merge(s1::AbstractWeightedReservoirSample, s2::AbstractWeightedReservoirSample) + error("Merging is not supported for weighted sampling") +end + +function Base.merge!(s1::AbstractWeightedReservoirSample, s2::AbstractWeightedReservoirSample) + error("Merging is not supported for weighted sampling") +end + function value(s::AbstractWeightedWorReservoirSampleMulti) if n_seen(s) < s.n return first.(s.value.valtree)[1:n_seen(s)] diff --git a/src/WeightedSamplingSingle.jl b/src/WeightedSamplingSingle.jl index 9dcbd77..b3f2929 100644 --- a/src/WeightedSamplingSingle.jl +++ b/src/WeightedSamplingSingle.jl @@ -1,5 +1,12 @@ -mutable struct SampleSingleAlgAExpJ{T,R} <: AbstractReservoirSample +mutable struct SampleSingleAlgARes{T,R} <: AbstractWeightedReservoirSampleSingle + state::Float64 + const rng::R + value::T + SampleSingleAlgARes{T,R}(state, rng) where {T,R} = new{T,R}(state, rng) +end + +mutable struct SampleSingleAlgAExpJ{T,R} <: AbstractWeightedReservoirSampleSingle state::Float64 skip_w::Float64 const rng::R @@ -7,15 +14,26 @@ mutable struct SampleSingleAlgAExpJ{T,R} <: AbstractReservoirSample SampleSingleAlgAExpJ{T,R}(state, skip_w, rng) where {T,R} = new{T,R}(state, skip_w, rng) end +function ReservoirSample(rng::R, T, method::AlgARes) where {R<:AbstractRNG} + return SampleSingleAlgARes{T,R}(0.0, rng) +end function ReservoirSample(rng::R, T, method::AlgAExpJ) where {R<:AbstractRNG} return SampleSingleAlgAExpJ{T,R}(0.0, 0.0, rng) end -function value(s::SampleSingleAlgAExpJ) +function value(s::AbstractWeightedReservoirSampleSingle) s.state === 0.0 && return nothing return s.value end +@inline function update!(s::SampleSingleAlgARes, el, w) + priority = -randexp(s.rng)/w + if priority > s.state + s.state = priority + s.value = el + end + return s +end @inline function update!(s::SampleSingleAlgAExpJ, el, weight) s.state += weight if s.skip_w <= s.state @@ -30,7 +48,7 @@ function itsample(iter, wv::Function, method::ReservoirAlgorithm = algAExpJ) end function itsample(rng::AbstractRNG, iter, wv::Function, method::ReservoirAlgorithm = algAExpJ) - s = ReservoirSample(rng, Base.@default_eltype(iter), algAExpJ) + s = ReservoirSample(rng, calculate_eltype(iter), algAExpJ) for x in iter update!(s, x, wv(x)) end diff --git a/test/weighted_sampling_single_tests.jl b/test/weighted_sampling_single_tests.jl index 07d3377..3506863 100644 --- a/test/weighted_sampling_single_tests.jl +++ b/test/weighted_sampling_single_tests.jl @@ -1,6 +1,6 @@ @testset "Weighted sampling single tests" begin - @testset "method=$method" for method in (:(),) + @testset "method=$method" for method in (algARes, algAExpJ) wv(el) = 1.0 a, b = 1, 100 z = itsample(a:b, wv)