Skip to content

Commit

Permalink
Implement merge for supported methods (#65)
Browse files Browse the repository at this point in the history
  • Loading branch information
Tortar committed Apr 19, 2024
1 parent 7c54d69 commit 039eeaa
Show file tree
Hide file tree
Showing 7 changed files with 133 additions and 19 deletions.
2 changes: 0 additions & 2 deletions docs/src/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -91,5 +91,3 @@ julia> @btime itsample($rng, $iter, $wv, 10^4, algAExpJ);
julia> @btime sample($rng, collect($iter), Weights($wv.($iter)), 10^4; replace=false);
306.039 ms (43 allocations: 370.19 MiB)
```


11 changes: 7 additions & 4 deletions src/StreamSampling.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,17 +18,20 @@ const ordworsample = OrdWORSample()
abstract type AbstractReservoirSample end

# unweighted cases
abstract type AbstractReservoirSampleSingle <: AbstractReservoirSample end
abstract type AbstractReservoirSampleMulti <: AbstractReservoirSample end
abstract type AbstractWorReservoirSampleMulti <: AbstractReservoirSampleMulti end
abstract type AbstractOrdWorReservoirSampleMulti <: AbstractWorReservoirSampleMulti end
abstract type AbstractWrReservoirSampleMulti <: AbstractReservoirSampleMulti end
abstract type AbstractOrdWrReservoirSampleMulti <: AbstractWrReservoirSampleMulti end

# weighted cases
abstract type AbstractWeightedReservoirSampleMulti <: AbstractReservoirSample end
abstract type AbstractWeightedWorReservoirSampleMulti <: AbstractReservoirSample end
abstract type AbstractWeightedWrReservoirSampleMulti <: AbstractReservoirSample end
abstract type AbstractWeightedOrdWrReservoirSampleMulti <: AbstractReservoirSample end
abstract type AbstractWeightedReservoirSample <: AbstractReservoirSample end
abstract type AbstractWeightedReservoirSampleSingle <: AbstractWeightedReservoirSample end
abstract type AbstractWeightedReservoirSampleMulti <: AbstractWeightedReservoirSample end
abstract type AbstractWeightedWorReservoirSampleMulti <: AbstractWeightedReservoirSample end
abstract type AbstractWeightedWrReservoirSampleMulti <: AbstractWeightedReservoirSample end
abstract type AbstractWeightedOrdWrReservoirSampleMulti <: AbstractWeightedReservoirSample end

abstract type ReservoirAlgorithm end

Expand Down
62 changes: 62 additions & 0 deletions src/UnweightedSamplingMulti.jl
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,68 @@ end
# Ordering trait for with-replacement multi-element reservoirs: the ordered
# subtype answers `true`, every other with-replacement reservoir answers `false`.
function is_ordered(s::AbstractOrdWrReservoirSampleMulti)
    return true
end
function is_ordered(s::AbstractWrReservoirSampleMulti)
    return false
end

"""
    merge(s1::AbstractWorReservoirSampleMulti, s2::AbstractWorReservoirSampleMulti)

Build a new multi-element reservoir from two equally sized without-replacement
reservoirs, leaving both inputs untouched.

Each slot of the result is taken from `s2` with probability `n2 / (n1 + n2)` and from
`s1` otherwise, where `n1` and `n2` are the numbers of elements seen by each sample.
The result reuses `s1.rng` and reports `n1 + n2` seen elements.

Throws whatever `check_merging_support` throws when the two samples cannot be merged.
"""
function Base.merge(s1::AbstractWorReservoirSampleMulti, s2::AbstractWorReservoirSampleMulti)
    len1, _, n1, n2 = check_merging_support(s1, s2)
    n_tot = n1 + n2
    p = n2 / n_tot
    value = create_new_res_vec(s1, s2, p, len1)
    # The merged reservoir holds a vector of elements, so it must be a multi-element
    # sampler; the single-element `SampleSingleAlgR` cannot be built from this call.
    # NOTE(review): assumes `SampleMultiAlgR` is constructible from
    # (seen_k, rng, value), mirroring the `SampleMultiAlgRSWRSKIP` call used for the
    # with-replacement merge -- confirm against the struct definition.
    return SampleMultiAlgR(n_tot, s1.rng, value)
end
"""
    merge(s1::AbstractWrReservoirSampleMulti, s2::AbstractWrReservoirSampleMulti)

Build a new with-replacement reservoir out of two equally sized ones, leaving both
inputs untouched.

Each slot is drawn from `s2` with probability `n2 / (n1 + n2)` and from `s1`
otherwise. The result reuses `s1.rng`, counts `n1 + n2` seen elements, and has its
skip recomputed before being returned.
"""
function Base.merge(s1::AbstractWrReservoirSampleMulti, s2::AbstractWrReservoirSampleMulti)
    size1, _, seen1, seen2 = check_merging_support(s1, s2)
    total_seen = seen1 + seen2
    merged_values = create_new_res_vec(s1, s2, seen2 / total_seen, size1)
    merged = SampleMultiAlgRSWRSKIP(total_seen, s1.rng, merged_values)
    recompute_skip!(merged, size1)
    return merged
end

"""
    merge!(s1::SampleMultiAlgR, s2::AbstractWorReservoirSampleMulti)

Merge `s2` into `s1` in place and return `s1`.

Each slot of `s1` is overwritten by the matching slot of `s2` with probability
`n2 / (n1 + n2)`; afterwards `s1` reports `n1 + n2` seen elements.
"""
function Base.merge!(s1::SampleMultiAlgR, s2::AbstractWorReservoirSampleMulti)
    size1, _, seen1, seen2 = check_merging_support(s1, s2)
    total_seen = seen1 + seen2
    return merge_res_vec!(s1, s2, seen2 / total_seen, size1, total_seen)
end
"""
    merge!(s1::SampleMultiAlgRSWRSKIP, s2::AbstractWrReservoirSampleMulti)

Merge `s2` into `s1` in place and return `s1`.

Slots of `s1` are overwritten by the matching slots of `s2` with probability
`n2 / (n1 + n2)`, the seen counter becomes `n1 + n2`, and the skip of `s1` is
recomputed for the merged state.
"""
function Base.merge!(s1::SampleMultiAlgRSWRSKIP, s2::AbstractWrReservoirSampleMulti)
    size1, _, seen1, seen2 = check_merging_support(s1, s2)
    total_seen = seen1 + seen2
    merged = merge_res_vec!(s1, s2, seen2 / total_seen, size1, total_seen)
    recompute_skip!(merged, size1)
    return merged
end
# In-place merging is rejected for the algL sampler: this method always throws.
# NOTE(review): this definition sits with the multi-element samplers but dispatches
# on `SampleSingleAlgL`; confirm whether the multi-element algL type was intended.
Base.merge!(s1::SampleSingleAlgL, s2::AbstractWorReservoirSampleMulti) =
    error("Merging into a ReservoirSample using method algL is not supported")

"""
    check_merging_support(s1, s2)

Validate that the reservoirs `s1` and `s2` can be merged and return
`(len1, len2, n1, n2)`: the lengths of their `value` vectors and the number of
elements each one has seen (as reported by `n_seen`).

Throws an `ErrorException` when the two reservoirs have different lengths, or when
either of them has seen fewer elements than its length.
"""
function check_merging_support(s1, s2)
    len1, len2 = length(s1.value), length(s2.value)
    len1 != len2 && error("Merging samples with different sizes is not supported")
    n1, n2 = n_seen(s1), n_seen(s2)
    # `&&` binds tighter than `||`, so the two comparisons must be grouped explicitly;
    # written as `a || b && error(...)` an under-filled `s1` would never be rejected.
    if n1 < len1 || n2 < len2
        error(
            "Merging samples which have seen fewer elements than their size ",
            "is not supported",
        )
    end
    return len1, len2, n1, n2
end

"""
    create_new_res_vec(s1, s2, p, len1)

Return a fresh vector (allocated with `similar(s1.value)`) whose first `len1` slots
are filled slot by slot: slot `j` comes from `s2.value[j]` when a draw from `s1.rng`
falls below `p`, and from `s1.value[j]` otherwise. Neither input is modified.
"""
function create_new_res_vec(s1, s2, p, len1)
    merged = similar(s1.value)
    @inbounds for slot in 1:len1
        # One draw per slot, always taken from the first sample's generator.
        source = rand(s1.rng) < p ? s2 : s1
        merged[slot] = source.value[slot]
    end
    return merged
end

"""
    merge_res_vec!(s1, s2, p, len1, n_tot)

Overwrite slots of `s1.value` with the matching slots of `s2.value`, each with
probability `p` (one draw from `s1.rng` per slot, for slots `1:len1`), then set
`s1.seen_k` to `n_tot`. Returns the mutated `s1`.
"""
function merge_res_vec!(s1, s2, p, len1, n_tot)
    @inbounds for slot in 1:len1
        # Keep the current element unless the draw selects the second sample.
        rand(s1.rng) < p || continue
        s1.value[slot] = s2.value[slot]
    end
    s1.seen_k = n_tot
    return s1
end

function value(s::AbstractWorReservoirSampleMulti)
if n_seen(s) < length(s.value)
return s.value[1:n_seen(s)]
Expand Down
43 changes: 34 additions & 9 deletions src/UnweightedSamplingSingle.jl
Original file line number Diff line number Diff line change
@@ -1,46 +1,49 @@

mutable struct SampleSingleAlgL{T,R} <: AbstractReservoirSample
mutable struct SampleSingleAlgL{T,R} <: AbstractReservoirSampleSingle
state::Float64
seen_k::Int
skip_k::Int
const rng::R
value::T
SampleSingleAlgL{T,R}(state, skip_k, rng) where {T,R} = new{T,R}(state, skip_k, rng)
SampleSingleAlgL{T,R}(state, seen_k, skip_k, rng) where {T,R} = new{T,R}(state, seen_k, skip_k, rng)
end

mutable struct SampleSingleAlgR{T,R} <: AbstractReservoirSample
state::Int
mutable struct SampleSingleAlgR{T,R} <: AbstractReservoirSampleSingle
seen_k::Int
const rng::R
value::T
SampleSingleAlgR{T,R}(state, rng) where {T,R} = new{T,R}(state, rng)
SampleSingleAlgR{T,R}(seen_k, rng, value) where {T,R} = new{T,R}(seen_k, rng, value)
SampleSingleAlgR{T,R}(seen_k, rng) where {T,R} = new{T,R}(seen_k, rng)
end

function value(s::SampleSingleAlgL)
s.state === 1.0 && return nothing
return s.value
end
function value(s::SampleSingleAlgR)
s.state === 0 && return nothing
s.seen_k === 0 && return nothing
return s.value
end

# Convenience method: build a reservoir for elements of type `T` using the default
# random number generator; `method` defaults to `algL`.
ReservoirSample(T, method::ReservoirAlgorithm = algL) =
    ReservoirSample(Random.default_rng(), T, method)
function ReservoirSample(rng::AbstractRNG, T, method::AlgL = algL)
return SampleSingleAlgL{T, typeof(rng)}(1.0, 0, rng)
return SampleSingleAlgL{T, typeof(rng)}(1.0, 0, 0, rng)
end
# algR single-element reservoir: starts with a zero counter and no stored value.
ReservoirSample(rng::AbstractRNG, T, method::AlgR) =
    SampleSingleAlgR{T, typeof(rng)}(0, rng)

@inline function update!(s::SampleSingleAlgR, el)
s.state += 1
if rand(s.rng) <= 1/s.state
s.seen_k += 1
if rand(s.rng) <= 1/s.seen_k
s.value = el
end
return s
end
@inline function update!(s::SampleSingleAlgL, el)
s.seen_k += 1
if s.skip_k > 0
s.skip_k -= 1
else
Expand All @@ -51,6 +54,28 @@ end
return s
end

"""
    merge(s1::AbstractReservoirSampleSingle, s2::AbstractReservoirSampleSingle)

Combine two single-element reservoirs into a new `SampleSingleAlgR`, leaving both
inputs untouched.

The value of `s1` is kept with probability `n1 / (n1 + n2)`, otherwise the value of
`s2` is used, where `n1` and `n2` are the numbers of elements seen by each sample.
The result reuses `s1.rng` and reports `n1 + n2` seen elements.
"""
function Base.merge(s1::AbstractReservoirSampleSingle, s2::AbstractReservoirSampleSingle)
    seen1, seen2 = n_seen(s1), n_seen(s2)
    total_seen = seen1 + seen2
    rng = s1.rng
    # NOTE(review): when both samples are empty the ratio is 0/0 (NaN), the
    # comparison is false, and `s2.value` is read.
    if rand(rng) < seen1 / total_seen
        chosen = s1.value
    else
        chosen = s2.value
    end
    return SampleSingleAlgR{typeof(chosen), typeof(rng)}(total_seen, rng, chosen)
end

"""
    merge!(s1::SampleSingleAlgR, s2::AbstractReservoirSampleSingle)

Merge `s2` into `s1` in place and return `s1`.

The stored value of `s1` is replaced by the one of `s2` with probability
`n2 / (n1 + n2)`; the seen counter of `s1` becomes `n1 + n2` either way.
"""
function Base.merge!(s1::SampleSingleAlgR, s2::AbstractReservoirSampleSingle)
    seen1, seen2 = n_seen(s1), n_seen(s2)
    total_seen = seen1 + seen2
    draw = rand(s1.rng)
    # Adopt the other sample's element only when the draw lands in its share.
    draw < seen2 / total_seen && (s1.value = s2.value)
    s1.seen_k = total_seen
    return s1
end
# In-place merging is rejected for the algL single-element sampler: always throws.
Base.merge!(s1::SampleSingleAlgL, s2::AbstractReservoirSampleSingle) =
    error("Merging into a ReservoirSample using method algL is not supported")

# Convenience method: sample from `iter` with the default random number generator;
# `method` defaults to `algL`.
itsample(iter, method::ReservoirAlgorithm = algL) =
    itsample(Random.default_rng(), iter, method)
Expand Down
8 changes: 8 additions & 0 deletions src/WeightedSamplingMulti.jl
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,14 @@ end
# Ordering trait for the weighted with-replacement skip samplers: only the `Ord`
# variant reports `true`.
function is_ordered(s::SampleMultiOrdAlgWRSWRSKIP)
    return true
end
function is_ordered(s::SampleMultiAlgWRSWRSKIP)
    return false
end

# Weighted reservoirs cannot be merged: this method always throws.
Base.merge(s1::AbstractWeightedReservoirSample, s2::AbstractWeightedReservoirSample) =
    error("Merging is not supported for weighted sampling")

# Weighted reservoirs cannot be merged in place either: this method always throws.
Base.merge!(s1::AbstractWeightedReservoirSample, s2::AbstractWeightedReservoirSample) =
    error("Merging is not supported for weighted sampling")

function value(s::AbstractWeightedWorReservoirSampleMulti)
if n_seen(s) < s.n
return first.(s.value.valtree)[1:n_seen(s)]
Expand Down
24 changes: 21 additions & 3 deletions src/WeightedSamplingSingle.jl
Original file line number Diff line number Diff line change
@@ -1,21 +1,39 @@

mutable struct SampleSingleAlgAExpJ{T,R} <: AbstractReservoirSample
mutable struct SampleSingleAlgARes{T,R} <: AbstractWeightedReservoirSampleSingle
state::Float64
const rng::R
value::T
SampleSingleAlgARes{T,R}(state, rng) where {T,R} = new{T,R}(state, rng)
end

mutable struct SampleSingleAlgAExpJ{T,R} <: AbstractWeightedReservoirSampleSingle
state::Float64
skip_w::Float64
const rng::R
value::T
SampleSingleAlgAExpJ{T,R}(state, skip_w, rng) where {T,R} = new{T,R}(state, skip_w, rng)
end

# algARes single-element weighted reservoir: starts with state 0.0 and no value.
ReservoirSample(rng::R, T, method::AlgARes) where {R<:AbstractRNG} =
    SampleSingleAlgARes{T,R}(0.0, rng)
# algAExpJ single-element weighted reservoir: starts with state 0.0, skip weight 0.0
# and no value.
ReservoirSample(rng::R, T, method::AlgAExpJ) where {R<:AbstractRNG} =
    SampleSingleAlgAExpJ{T,R}(0.0, 0.0, rng)

function value(s::SampleSingleAlgAExpJ)
function value(s::AbstractWeightedReservoirSampleSingle)
s.state === 0.0 && return nothing
return s.value
end

"""
    update!(s::SampleSingleAlgARes, el, w)

Offer element `el` with weight `w` to the single-element weighted reservoir `s`.

A priority `-randexp(s.rng) / w` is drawn for `el`; `el` becomes the stored value
when the reservoir is still empty or when its priority exceeds the best priority
recorded so far in `s.state`. Returns `s`.
"""
@inline function update!(s::SampleSingleAlgARes, el, w)
    priority = -randexp(s.rng) / w
    # `s.state === 0.0` is the "nothing stored yet" marker: it is the value the
    # sampler is constructed with and the one `value` tests before returning.
    # Priorities drawn for positive weights are never positive, so without this
    # check no element could beat the initial state and nothing would be stored.
    if s.state === 0.0 || priority > s.state
        s.state = priority
        s.value = el
    end
    return s
end
@inline function update!(s::SampleSingleAlgAExpJ, el, weight)
s.state += weight
if s.skip_w <= s.state
Expand All @@ -30,7 +48,7 @@ function itsample(iter, wv::Function, method::ReservoirAlgorithm = algAExpJ)
end

function itsample(rng::AbstractRNG, iter, wv::Function, method::ReservoirAlgorithm = algAExpJ)
s = ReservoirSample(rng, Base.@default_eltype(iter), algAExpJ)
s = ReservoirSample(rng, calculate_eltype(iter), algAExpJ)
for x in iter
update!(s, x, wv(x))
end
Expand Down
2 changes: 1 addition & 1 deletion test/weighted_sampling_single_tests.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@

@testset "Weighted sampling single tests" begin
@testset "method=$method" for method in (:(),)
@testset "method=$method" for method in (algARes, algAExpJ)
wv(el) = 1.0
a, b = 1, 100
z = itsample(a:b, wv)
Expand Down

0 comments on commit 039eeaa

Please sign in to comment.