Skip to content

Commit

Permalink
Implement ordered version of weighted sampling (#68)
Browse files Browse the repository at this point in the history
  • Loading branch information
Tortar committed Apr 20, 2024
1 parent 5539bed commit 01c46b6
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 16 deletions.
62 changes: 48 additions & 14 deletions src/WeightedSamplingMulti.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,13 @@ mutable struct SampleMultiAlgARes{BH,R} <: AbstractWeightedWorReservoirSampleMul
const value::BH
end

mutable struct SampleMultiOrdAlgARes{BH,R} <: AbstractWeightedWorReservoirSampleMulti
seen_k::Int
n::Int
const rng::R
const value::BH
end

mutable struct SampleMultiAlgAExpJ{BH,R} <: AbstractWeightedWorReservoirSampleMulti
state::Float64
min_priority::Float64
Expand All @@ -15,6 +22,15 @@ mutable struct SampleMultiAlgAExpJ{BH,R} <: AbstractWeightedWorReservoirSampleMu
const value::BH
end

mutable struct SampleMultiOrdAlgAExpJ{BH,R} <: AbstractWeightedWorReservoirSampleMulti
state::Float64
min_priority::Float64
seen_k::Int
const n::Int
const rng::R
const value::BH
end

mutable struct SampleMultiAlgWRSWRSKIP{T,R} <: AbstractWeightedWrReservoirSampleMulti
state::Float64
skip_w::Float64
Expand All @@ -35,20 +51,24 @@ mutable struct SampleMultiOrdAlgWRSWRSKIP{T,R} <: AbstractWeightedWrReservoirSam
end

function ReservoirSample(rng::AbstractRNG, T, n::Integer, method::AlgAExpJ; ordered = false)
value = BinaryHeap(Base.By(last), Pair{T, Float64}[])
sizehint!(value, n)
if ordered
error("Not implemented yet")
value = BinaryHeap(Base.By(last), Tuple{T, Int, Float64}[])
sizehint!(value, n)
return SampleMultiOrdAlgAExpJ(0.0, 0.0, 0, n, rng, value)
else
value = BinaryHeap(Base.By(last), Pair{T, Float64}[])
sizehint!(value, n)
return SampleMultiAlgAExpJ(0.0, 0.0, 0, n, rng, value)
end
end
function ReservoirSample(rng::AbstractRNG, T, n::Integer, method::AlgARes; ordered = false)
value = BinaryHeap(Base.By(last), Pair{T, Float64}[])
sizehint!(value, n)
if ordered
error("Not implemented yet")
value = BinaryHeap(Base.By(last), Tuple{T, Int, Float64}[])
sizehint!(value, n)
return SampleMultiOrdAlgARes(0, n, rng, value)
else
value = BinaryHeap(Base.By(last), Pair{T, Float64}[])
sizehint!(value, n)
return SampleMultiAlgARes(0, n, rng, value)
end
end
Expand All @@ -63,33 +83,33 @@ function ReservoirSample(rng::AbstractRNG, T, n::Integer, method::AlgWRSWRSKIP;
end
end

@inline function update!(s::SampleMultiAlgARes, el, w)
@inline function update!(s::Union{SampleMultiAlgARes, SampleMultiOrdAlgARes}, el, w)
n = s.n
s.seen_k += 1
priority = -randexp(s.rng)/w
if s.seen_k <= n
push!(s.value, el => priority)
push_value!(s, el, priority)
else
min_priority = last(first(s.value))
if priority > min_priority
pop!(s.value)
push!(s.value, el => priority)
push_value!(s, el, priority)
end
end
return s
end
@inline function update!(s::SampleMultiAlgAExpJ, el, w)
@inline function update!(s::Union{SampleMultiAlgAExpJ, SampleMultiOrdAlgAExpJ}, el, w)
n = s.n
s.seen_k += 1
s.state -= w
if s.seen_k <= n
priority = exp(-randexp(s.rng)/w)
push!(s.value, el => priority)
push_value!(s, el, priority)
s.seen_k == n && @inline recompute_skip!(s)
elseif s.state <= 0.0
priority = @inline compute_skip_priority(s, w)
pop!(s.value)
push!(s.value, el => priority)
push_value!(s, el, priority)
@inline recompute_skip!(s)
end
return s
Expand Down Expand Up @@ -138,7 +158,7 @@ function compute_skip_priority(s, w)
return exp(log(rand(s.rng, Uniform(t,1)))/w)
end

function recompute_skip!(s::SampleMultiAlgAExpJ)
function recompute_skip!(s::Union{SampleMultiAlgAExpJ, SampleMultiOrdAlgAExpJ})
s.min_priority = last(first(s.value))
s.state = -randexp(s.rng)/log(s.min_priority)
end
Expand All @@ -147,6 +167,12 @@ function recompute_skip!(s::Union{SampleMultiAlgWRSWRSKIP, SampleMultiOrdAlgWRSW
s.skip_w = s.state/q
end

function push_value!(s::Union{SampleMultiAlgARes, SampleMultiAlgAExpJ}, el, priority)
push!(s.value, el => priority)
end
function push_value!(s::Union{SampleMultiOrdAlgARes, SampleMultiOrdAlgAExpJ}, el, priority)
push!(s.value, (el, s.seen_k, priority))
end
update_order_single!(s::SampleMultiAlgWRSWRSKIP, r) = nothing
function update_order_single!(s::SampleMultiOrdAlgWRSWRSKIP, r)
s.ord[r] = n_seen(s)
Expand All @@ -162,7 +188,7 @@ is_ordered(s::SampleMultiAlgWRSWRSKIP) = false

function value(s::AbstractWeightedWorReservoirSampleMulti)
if n_seen(s) < s.n
return first.(s.value.valtree)[1:n_seen(s)]
return first.(s.value.valtree[1:n_seen(s)])
else
return first.(s.value.valtree)
end
Expand All @@ -175,6 +201,14 @@ function value(s::AbstractWeightedWrReservoirSampleMulti)
end
end

function ordered_value(s::Union{SampleMultiOrdAlgARes, SampleMultiOrdAlgAExpJ})
if n_seen(s) < length(s.value)
vals = s.value.valtree[1:n_seen(s)]
else
vals = s.value.valtree
end
return first.(vals[sortperm(map(x -> x[2], vals))])
end
function ordered_value(s::SampleMultiOrdAlgWRSWRSKIP)
if n_seen(s) < length(s.value)
return sample(s.rng, s.value[1:n_seen(s)], weights(s.weights[1:n_seen(s)]), length(s.value); ordered=true)
Expand Down
3 changes: 1 addition & 2 deletions test/weighted_sampling_multi_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,7 @@ function prob_no_replace(k)
end

@testset "Weighted sampling multi tests" begin
combs = vec(collect(Iterators.product([(algAExpJ, algARes, algWRSWRSKIP), (false, )]...)))
push!(combs, (algWRSWRSKIP, true))
combs = Iterators.product([(algAExpJ, algARes, algWRSWRSKIP), (false, true)]...)
@testset "method=$method ordered=$ordered" for (method, ordered) in combs
a, b = 1, 10
# test return values of iter with known lengths are inrange
Expand Down

0 comments on commit 01c46b6

Please sign in to comment.