Skip to content

Commit

Permalink
add ordered reservoir sampling
Browse files Browse the repository at this point in the history
  • Loading branch information
Tortar committed Dec 27, 2023
1 parent 0dbccbd commit 2757c0c
Showing 1 changed file with 44 additions and 7 deletions.
51 changes: 44 additions & 7 deletions src/UnweightedSamplingMulti.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,19 +13,50 @@ function itsample(rng::AbstractRNG, iter, n::Int;
if replace
error("Not implemented yet")
else
unweighted_resorvoir_sampling(rng, iter, n)
unweighted_resorvoir_sampling(rng, iter, n, Val(ordered))
end
else
unweighted_resorvoir_sampling(rng, iter, n)
unweighted_resorvoir_sampling(rng, iter, n, Val(ordered))
#double_scan_sampling(rng, iter, n, replace, ordered)
end
else
unweighted_resorvoir_sampling(rng, iter, n)
unweighted_resorvoir_sampling(rng, iter, n, Val(ordered))
#single_scan_sampling(rng, iter, n, replace, ordered)
end
end

function unweighted_resorvoir_sampling(rng, iter, n::Int)
function unweighted_resorvoir_sampling(rng, iter, n::Int, ::Val{false})
iter_type = eltype(iter)
it = iterate(iter)
isnothing(it) && return iter_type[]
el, state = it
reservoir = Vector{iter_type}(undef, n)
reservoir[1] = el
for i in 2:n
it = iterate(iter, state)
isnothing(it) && return shuffle!(reservoir[1:i-1])
el, state = it
@inbounds reservoir[i] = el
end
u = randexp(rng)
while true
w = exp(-u/n)
skip_counter = ceil(Int, randexp(rng)/log(1-w))
while skip_counter != 0
skip_res = iterate(iter, state)
isnothing(skip_res) && return shuffle!(reservoir)
state = skip_res[2]
skip_counter += 1
end
it = iterate(iter, state)
isnothing(it) && return shuffle!(reservoir)
el, state = it
@inbounds reservoir[rand(rng, 1:n)] = el
u += randexp(rng)
end
end

function unweighted_resorvoir_sampling(rng, iter, n::Int, ::Val{true})
iter_type = eltype(iter)
it = iterate(iter)
isnothing(it) && return iter_type[]
Expand All @@ -39,19 +70,25 @@ function unweighted_resorvoir_sampling(rng, iter, n::Int)
@inbounds reservoir[i] = el
end
u = randexp(rng)
o = [i for i in 1:n]
k = n
while true
w = exp(-u/n)
skip_counter = ceil(Int, randexp(rng)/log(1-w))
k += -skip_counter
while skip_counter != 0
skip_res = iterate(iter, state)
isnothing(skip_res) && return reservoir
isnothing(skip_res) && return reservoir[sortperm(o)]
state = skip_res[2]
skip_counter += 1
end
it = iterate(iter, state)
isnothing(it) && return reservoir
k += 1
isnothing(it) && return reservoir[sortperm(o)]
el, state = it
reservoir[rand(rng, 1:n)] = el
v = rand(rng, 1:n)
@inbounds reservoir[v] = el
@inbounds o[v] = k
u += randexp(rng)
end
end
Expand Down

0 comments on commit 2757c0c

Please sign in to comment.