Skip to content

Commit

Permalink
simplify reservoir multi code
Browse files Browse the repository at this point in the history
  • Loading branch information
Tortar committed Jan 13, 2024
1 parent c001e2c commit 631a66d
Showing 1 changed file with 46 additions and 78 deletions.
124 changes: 46 additions & 78 deletions src/UnweightedSamplingMulti.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ function reservoir_sample(rng, iter, n; replace = false, ordered = false)
end
end

function reservoir_sample(rng, iter, n::Int, ::WORSample)
function reservoir_sample(rng, iter, n::Int, is::Union{WORSample, OrdWORSample})
iter_type = Base.@default_eltype(iter)
it = iterate(iter)
isnothing(it) && return iter_type[]
Expand All @@ -38,53 +38,26 @@ function reservoir_sample(rng, iter, n::Int, ::WORSample)
reservoir[1] = el
@inbounds for i in 2:n
it = iterate(iter, state)
isnothing(it) && return shuffle!(rng, resize!(reservoir, i-1))
isnothing(it) && return transform(rng, resize!(reservoir, i-1), nothing, is)
el, state = it
reservoir[i] = el
end
u = randexp(rng)
k, order = instantiate_order(n, is)
@inbounds while true
w = exp(-u/n)
skip_k = -ceil(Int, randexp(rng)/log(1-w))
it = skip_ahead_unknown_end(iter, state, skip_k)
isnothing(it) && return shuffle!(rng, reservoir)
el, state = it
reservoir[rand(rng, 1:n)] = el
u += randexp(rng)
end
end

function reservoir_sample(rng, iter, n::Int, ::OrdWORSample)
iter_type = Base.@default_eltype(iter)
it = iterate(iter)
isnothing(it) && return iter_type[]
el, state = it
reservoir = Vector{iter_type}(undef, n)
reservoir[1] = el
@inbounds for i in 2:n
it = iterate(iter, state)
isnothing(it) && return resize!(reservoir, i-1)
el, state = it
reservoir[i] = el
end
u = randexp(rng)
o = [i for i in 1:n]
k = n
@inbounds while true
w = exp(-u/n)
skip_k = -ceil(Int, randexp(rng)/log(1-w))
it = skip_ahead_unknown_end(iter, state, skip_k)
isnothing(it) && return reservoir[sortperm(o)]
isnothing(it) && return transform(rng, reservoir, order, is)
el, state = it
q = rand(rng, 1:n)
reservoir[q] = el
k += skip_k + 1
o[q] = k
reservoir[q] = el
k = update_order!(k, skip_k, q, order, is)
u += randexp(rng)
end
end

function reservoir_sample(rng, iter, n::Int, ::WRSample)
function reservoir_sample(rng, iter, n::Int, is::Union{WRSample, OrdWRSample})
iter_type = Base.@default_eltype(iter)
it = iterate(iter)
isnothing(it) && return iter_type[]
Expand All @@ -93,16 +66,18 @@ function reservoir_sample(rng, iter, n::Int, ::WRSample)
reservoir[1] = el
@inbounds for i in 2:n
it = iterate(iter, state)
isnothing(it) && return sample(rng, resize!(reservoir, i-1), n)
isnothing(it) && return sample(rng, resize!(reservoir, i-1), n,
ordered=is isa WRSample ? false : true)
el, state = it
reservoir[i] = el
end
reservoir = sample(rng, reservoir, n)
i, order = instantiate_order(n, is)
i = n
reservoir = sample(rng, reservoir, n, ordered=is isa WRSample ? false : true)
@inbounds while true
skip_k = skip(rng, i, n)
it = skip_ahead_unknown_end(iter, state, skip_k)
isnothing(it) && return shuffle!(rng, reservoir)
isnothing(it) && return transform(rng, reservoir, order, is)
el, state = it
i += skip_k + 1
p = 1/i
Expand All @@ -112,52 +87,13 @@ function reservoir_sample(rng, iter, n::Int, ::WRSample)
if k == 1
r = rand(rng, 1:n)
reservoir[r] = el
update_order_single!(i, r, order, is)
else
for j in 1:k
r = rand(rng, j:n)
reservoir[r] = el
reservoir[r], reservoir[j] = reservoir[j], reservoir[r]
end
end
end
end

function reservoir_sample(rng, iter, n::Int, ::OrdWRSample)
iter_type = Base.@default_eltype(iter)
it = iterate(iter)
isnothing(it) && return iter_type[]
reservoir = Vector{iter_type}(undef, n)
el, state = it
reservoir[1] = el
@inbounds for i in 2:n
it = iterate(iter, state)
isnothing(it) && return sample(rng, resize!(reservoir, i-1), n, ordered=true)
el, state = it
reservoir[i] = el
end
o = [i for i in 1:n]
reservoir = sample(rng, reservoir, n, ordered=true)
i = n
@inbounds while true
skip_k = skip(rng, i, n)
it = skip_ahead_unknown_end(iter, state, skip_k)
isnothing(it) && return reservoir[sortperm(o)]
el, state = it
i += skip_k + 1
p = 1/i
z = (1-p)^(n-3)
q = rand(rng, Uniform(z*(1-p)*(1-p)*(1-p),1))
k = choose(n, p, q, z)
if k == 1
r = rand(rng, 1:n)
reservoir[r] = el
o[r] = i
else
for j in 1:k
r = rand(rng, j:n)
reservoir[r] = el
reservoir[r], reservoir[j] = reservoir[j], reservoir[r]
o[r], o[j] = o[j], i
update_order_multi!(i, r, j, order, is)
end
end
end
Expand Down Expand Up @@ -253,3 +189,35 @@ function skip_ahead_unknown_end(iter, state, n)
isnothing(it) && return nothing
return it
end

"""
    instantiate_order(n, method)

Build the initial order-tracking state for a reservoir holding `n` elements.

Unordered methods (`WORSample`, `WRSample`) track nothing and receive
`(nothing, nothing)`. Ordered methods (`OrdWORSample`, `OrdWRSample`) receive
`(n, order)`, where `order` is the vector `1, 2, ..., n`.
"""
instantiate_order(n, ::Union{WORSample, WRSample}) = (nothing, nothing)
# Ordered variants: the first component is `n` itself, the second the positions 1..n.
instantiate_order(n, ::Union{OrdWORSample, OrdWRSample}) = (n, collect(1:n))

"""
    update_order!(k, skip_k, q, order, method)

Record in `order` the counter value associated with reservoir slot `q`.

For `OrdWORSample` the counter `k` is advanced by `skip_k + 1`, the new value
is stored in `order[q]`, and the new value is returned. For `WORSample` no
order is tracked, so nothing is touched and `nothing` is returned.
"""
function update_order!(k, skip_k, q, order, ::WORSample)
    return nothing
end
function update_order!(k, skip_k, q, order, ::OrdWORSample)
    advanced = k + (skip_k + 1)
    order[q] = advanced
    return advanced
end

"""
    update_order_single!(k, r, order, method)

Store the value `k` at position `r` of `order` when the sampling method is
ordered (`OrdWRSample`) and return `k`. For `WRSample` this is a no-op that
returns `nothing`.
"""
function update_order_single!(k, r, order, ::WRSample)
    return nothing
end
function update_order_single!(k, r, order, ::OrdWRSample)
    order[r] = k
    return k
end

"""
    update_order_multi!(k, r, j, order, method)

Update `order` for a pair of positions `r` and `j` when the sampling method is
ordered (`OrdWRSample`): the value previously stored at `order[j]` is moved to
`order[r]`, and `order[j]` is set to `k`. The tuple `(old order[j], k)` is
returned. For `WRSample` this is a no-op that returns `nothing`.
"""
function update_order_multi!(k, r, j, order, ::WRSample)
    return nothing
end
function update_order_multi!(k, r, j, order, ::OrdWRSample)
    # Read the displaced value first so that the case r == j still ends with
    # order[j] == k.
    displaced = order[j]
    order[r] = displaced
    order[j] = k
    return (displaced, k)
end

"""
    transform(rng, reservoir, order, method)

Produce the final sample from `reservoir` according to the sampling method.

- Unordered methods (`WORSample`, `WRSample`): shuffle `reservoir` in place
  using `rng` and return it; `order` is ignored.
- Ordered methods (`OrdWORSample`, `OrdWRSample`) with an `order` vector:
  return a new vector with the elements of `reservoir` arranged by the sorting
  permutation of `order`.
- Ordered methods with `order === nothing`: return `reservoir` unchanged.
"""
transform(rng, reservoir, order, ::Union{WORSample, WRSample}) = shuffle!(rng, reservoir)
function transform(rng, reservoir, order, ::Union{OrdWORSample, OrdWRSample})
    perm = sortperm(order)
    return reservoir[perm]
end
# No order information was recorded: the reservoir is returned as it is.
transform(rng, reservoir, ::Nothing, ::Union{OrdWORSample, OrdWRSample}) = reservoir

0 comments on commit 631a66d

Please sign in to comment.