Skip to content

Commit

Permalink
improve perf of reservoir_sample when replace=true + some fixes (#12)
Browse files Browse the repository at this point in the history
  • Loading branch information
Tortar committed Jan 6, 2024
1 parent d1e30e0 commit 5a254ea
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 20 deletions.
53 changes: 34 additions & 19 deletions src/UnweightedSamplingMulti.jl
Original file line number Diff line number Diff line change
Expand Up @@ -32,12 +32,12 @@ function reservoir_sample(rng, iter, n::Int, ::Val{false}, ::Val{false})
skip_counter = ceil(Int, randexp(rng)/log(1-w))
while skip_counter != 0
skip_res = iterate(iter, state)
isnothing(skip_res) && return shuffle!(reservoir)
isnothing(skip_res) && return shuffle!(rng, reservoir)
state = skip_res[2]
skip_counter += 1
end
it = iterate(iter, state)
isnothing(it) && return shuffle!(reservoir)
isnothing(it) && return shuffle!(rng, reservoir)
el, state = it
@inbounds reservoir[rand(rng, 1:n)] = el
u += randexp(rng)
Expand Down Expand Up @@ -85,28 +85,34 @@ function reservoir_sample(rng, iter, n::Int, ::Val{true}, ::Val{false})
iter_type = Base.@default_eltype(iter)
it = iterate(iter)
isnothing(it) && return iter_type[]
el, state = it
reservoir = Vector{iter_type}(undef, n)
for i in eachindex(reservoir)
el, state = it
reservoir[1] = el
@inbounds for i in 2:n
it = iterate(iter, state)
isnothing(it) && return sample(rng, resize!(reservoir, i-1), n)
el, state = it
reservoir[i] = el
end
i = 1
reservoir = sample(rng, reservoir, n)
i = n
while true
t = skip(rng, i, n)
skip_counter = t
while skip_counter != 0
skip_res = iterate(iter, state)
isnothing(skip_res) && return shuffle!(reservoir)
isnothing(skip_res) && return shuffle!(rng, reservoir)
state = skip_res[2]
skip_counter -= 1
end
it = iterate(iter, state)
isnothing(it) && return shuffle!(reservoir)
isnothing(it) && return shuffle!(rng, reservoir)
el, state = it
i += t + 1
p = 1/i
q = rand(rng, Uniform((1-p)^n,1))
k = choose(n, p, q)
z = (1-p)^(n-2)
q = rand(rng, Uniform(z*(1-p)*(1-p),1))
k = choose(n, p, q, z)
if k == 1
r = rand(rng, 1:n)
@inbounds reservoir[r] = el
Expand Down Expand Up @@ -145,8 +151,9 @@ function reservoir_sample(rng, iter, n::Int, ::Val{true}, ::Val{true})
el, state = it
i += t + 1
p = 1/i
q = rand(rng, Uniform((1-p)^n,1))
k = choose(n, p, q)
z = (1-p)^(n-3)
q = rand(rng, Uniform(z*(1-p)*(1-p)*(1-p),1))
k = choose(n, p, q, z)
if k == 1
r = rand(rng, 1:n)
@inbounds reservoir[r] = el
Expand All @@ -168,11 +175,15 @@ function skip(rng, n, m)
return t
end

function choose(n, p, q)
z = (1-p)^n + n*p*((1-p)^(n-1))
function choose(n, p, q, z)
m = 1-p
s = z
z = s*m*m*(m + n*p)
z > q && return 1
z += n*(n-1)*p*p*((1-p)^(n-2))/2
z += n*p*(n-1)*p*s*m/2
z > q && return 2
z += n*p*(n-1)*p*(n-2)*p*s/6
z > q && return 3
b = Binomial(n, p)
return quantile(b, q)
end
Expand All @@ -189,10 +200,14 @@ end
function sortedindices_sample(rng, iter, n::Int, N::Int, replace, ordered)
if N <= n
reservoir = collect(iter)
if ordered
return reservoir
if replace
return sample(rng, reservoir, n, ordered=ordered)
else
return shuffle!(reservoir)
if ordered
return reservoir
else
return shuffle!(reservoir)
end
end
end
iter_type = Base.@default_eltype(iter)
Expand Down Expand Up @@ -233,7 +248,7 @@ function sortedindices_sample(rng, iter, n::Int, N::Int, replace, ordered)
if ordered
return reservoir
else
return shuffle!(reservoir)
return shuffle!(rng, reservoir)
end
end

Expand All @@ -254,4 +269,4 @@ function get_sorted_indices(rng, n, N, replace)
else
return sort!(sample(rng, 1:N, n; replace=replace))
end
end
end
2 changes: 1 addition & 1 deletion test/unweighted_sampling_multi_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
@test all(x -> a <= x <= b, s)
@test typeof(s) == Vector{Int}
s = itsample(iter, 100, replace=replace, ordered=ordered)
@test length(s) == 10
@test replace ? length(s) == 100 : length(s) == 10
@test length(unique(s)) == 10

# test return values of iter with unknown lengths are inrange
Expand Down

0 comments on commit 5a254ea

Please sign in to comment.