diff --git a/src/UnweightedSamplingMulti.jl b/src/UnweightedSamplingMulti.jl index 186b92c..c660ce7 100644 --- a/src/UnweightedSamplingMulti.jl +++ b/src/UnweightedSamplingMulti.jl @@ -29,7 +29,7 @@ function reservoir_sample(rng, iter, n; replace = false, ordered = false) end end -function reservoir_sample(rng, iter, n::Int, ::WORSample) +function reservoir_sample(rng, iter, n::Int, is::Union{WORSample, OrdWORSample}) iter_type = Base.@default_eltype(iter) it = iterate(iter) isnothing(it) && return iter_type[] @@ -38,53 +38,26 @@ function reservoir_sample(rng, iter, n::Int, ::WORSample) reservoir[1] = el @inbounds for i in 2:n it = iterate(iter, state) - isnothing(it) && return shuffle!(rng, resize!(reservoir, i-1)) + isnothing(it) && return transform(rng, resize!(reservoir, i-1), nothing, is) el, state = it reservoir[i] = el end u = randexp(rng) + k, order = instantiate_order(n, is) @inbounds while true w = exp(-u/n) skip_k = -ceil(Int, randexp(rng)/log(1-w)) it = skip_ahead_unknown_end(iter, state, skip_k) - isnothing(it) && return shuffle!(rng, reservoir) - el, state = it - reservoir[rand(rng, 1:n)] = el - u += randexp(rng) - end -end - -function reservoir_sample(rng, iter, n::Int, ::OrdWORSample) - iter_type = Base.@default_eltype(iter) - it = iterate(iter) - isnothing(it) && return iter_type[] - el, state = it - reservoir = Vector{iter_type}(undef, n) - reservoir[1] = el - @inbounds for i in 2:n - it = iterate(iter, state) - isnothing(it) && return resize!(reservoir, i-1) - el, state = it - reservoir[i] = el - end - u = randexp(rng) - o = [i for i in 1:n] - k = n - @inbounds while true - w = exp(-u/n) - skip_k = -ceil(Int, randexp(rng)/log(1-w)) - it = skip_ahead_unknown_end(iter, state, skip_k) - isnothing(it) && return reservoir[sortperm(o)] + isnothing(it) && return transform(rng, reservoir, order, is) el, state = it q = rand(rng, 1:n) - reservoir[q] = el - k += skip_k + 1 - o[q] = k + reservoir[q] = el + k = update_order!(k, skip_k, q, order, is) u += randexp(rng) end end -function reservoir_sample(rng, iter, n::Int, ::WRSample) +function reservoir_sample(rng, iter, n::Int, is::Union{WRSample, OrdWRSample}) iter_type = Base.@default_eltype(iter) it = iterate(iter) isnothing(it) && return iter_type[] @@ -93,16 +66,18 @@ function reservoir_sample(rng, iter, n::Int, ::WRSample) reservoir[1] = el @inbounds for i in 2:n it = iterate(iter, state) - isnothing(it) && return sample(rng, resize!(reservoir, i-1), n) + isnothing(it) && return sample(rng, resize!(reservoir, i-1), n, + ordered=is isa WRSample ? false : true) el, state = it reservoir[i] = el end - reservoir = sample(rng, reservoir, n) + i, order = instantiate_order(n, is) i = n + reservoir = sample(rng, reservoir, n, ordered=is isa WRSample ? false : true) @inbounds while true skip_k = skip(rng, i, n) it = skip_ahead_unknown_end(iter, state, skip_k) - isnothing(it) && return shuffle!(rng, reservoir) + isnothing(it) && return transform(rng, reservoir, order, is) el, state = it i += skip_k + 1 p = 1/i @@ -112,52 +87,13 @@ function reservoir_sample(rng, iter, n::Int, ::WRSample) if k == 1 r = rand(rng, 1:n) reservoir[r] = el + update_order_single!(i, r, order, is) else for j in 1:k r = rand(rng, j:n) reservoir[r] = el reservoir[r], reservoir[j] = reservoir[j], reservoir[r] - end - end - end -end - -function reservoir_sample(rng, iter, n::Int, ::OrdWRSample) - iter_type = Base.@default_eltype(iter) - it = iterate(iter) - isnothing(it) && return iter_type[] - reservoir = Vector{iter_type}(undef, n) - el, state = it - reservoir[1] = el - @inbounds for i in 2:n - it = iterate(iter, state) - isnothing(it) && return sample(rng, resize!(reservoir, i-1), n, ordered=true) - el, state = it - reservoir[i] = el - end - o = [i for i in 1:n] - reservoir = sample(rng, reservoir, n, ordered=true) - i = n - @inbounds while true - skip_k = skip(rng, i, n) - it = skip_ahead_unknown_end(iter, state, skip_k) - isnothing(it) && return reservoir[sortperm(o)] - el, state = it - i += skip_k + 1 - p = 1/i - z = (1-p)^(n-3) - q = rand(rng, Uniform(z*(1-p)*(1-p)*(1-p),1)) - k = choose(n, p, q, z) - if k == 1 - r = rand(rng, 1:n) - reservoir[r] = el - o[r] = i - else - for j in 1:k - r = rand(rng, j:n) - reservoir[r] = el - reservoir[r], reservoir[j] = reservoir[j], reservoir[r] - o[r], o[j] = o[j], i + update_order_multi!(i, r, j, order, is) end end end @@ -253,3 +189,35 @@ function skip_ahead_unknown_end(iter, state, n) isnothing(it) && return nothing return it end + +instantiate_order(n, ::Union{WORSample, WRSample}) = nothing, nothing +function instantiate_order(n, ::Union{OrdWORSample, OrdWRSample}) + return n, [i for i in 1:n] +end + +update_order!(k, skip_k, q, order, ::WORSample) = nothing +function update_order!(k, skip_k, q, order, ::OrdWORSample) + k += skip_k + 1 + order[q] = k + return k +end + +update_order_single!(k, r, order, ::WRSample) = nothing +function update_order_single!(k, r, order, ::OrdWRSample) + order[r] = k +end + +update_order_multi!(k, r, j, order, ::WRSample) = nothing +function update_order_multi!(k, r, j, order, ::OrdWRSample) + order[r], order[j] = order[j], k +end + +function transform(rng, reservoir, order, ::Union{WORSample, WRSample}) + return shuffle!(rng, reservoir) +end +function transform(rng, reservoir, order, ::Union{OrdWORSample, OrdWRSample}) + return reservoir[sortperm(order)] +end +function transform(rng, reservoir, order::Nothing, ::Union{OrdWORSample, OrdWRSample}) + return reservoir +end