Skip to content

Commit

Permalink
Better methods
Browse files Browse the repository at this point in the history
  • Loading branch information
Tortar committed Dec 31, 2023
1 parent 14ed2be commit e847225
Showing 1 changed file with 15 additions and 12 deletions.
27 changes: 15 additions & 12 deletions src/UnweightedSamplingMulti.jl
Original file line number Diff line number Diff line change
@@ -1,29 +1,30 @@

function itsample(iter, n::Int, iter_type = Base.@default_eltype(iter);
function itsample(iter, n::Int;
replace = false, ordered = false, is_stateful = false)
return itsample(Random.GLOBAL_RNG, iter, n;
replace=replace, ordered=ordered, is_stateful=is_stateful)
end

function itsample(rng::AbstractRNG, iter, n::Int, iter_type = Base.@default_eltype(iter);
function itsample(rng::AbstractRNG, iter, n::Int;
replace = false, ordered = false, is_stateful = false)
IterHasKnownSize = Base.IteratorSize(iter)
if IterHasKnownSize isa NonIndexable
if is_stateful
if replace
error("Not implemented yet")
else
unweighted_resorvoir_sampling(rng, iter, n, Val(ordered), iter_type)
unweighted_resorvoir_sampling(rng, iter, n, Val(ordered))
end
else
double_scan_sampling(rng, iter, n, replace, ordered, iter_type)
double_scan_sampling(rng, iter, n, replace, ordered)
end
else
single_scan_sampling(rng, iter, n, replace, ordered, iter_type)
single_scan_sampling(rng, iter, n, replace, ordered)
end
end

function unweighted_resorvoir_sampling(rng, iter, n::Int, ::Val{false}, iter_type)
function unweighted_resorvoir_sampling(rng, iter, n::Int, ::Val{false})
iter_type = Base.@default_eltype(iter)
it = iterate(iter)
isnothing(it) && return iter_type[]
el, state = it
Expand Down Expand Up @@ -53,7 +54,8 @@ function unweighted_resorvoir_sampling(rng, iter, n::Int, ::Val{false}, iter_typ
end
end

function unweighted_resorvoir_sampling(rng, iter, n::Int, ::Val{true}, iter_type)
function unweighted_resorvoir_sampling(rng, iter, n::Int, ::Val{true})
iter_type = Base.@default_eltype(iter)
it = iterate(iter)
isnothing(it) && return iter_type[]
el, state = it
Expand Down Expand Up @@ -89,16 +91,16 @@ function unweighted_resorvoir_sampling(rng, iter, n::Int, ::Val{true}, iter_type
end
end

function double_scan_sampling(rng, iter, n::Int, replace, ordered, iter_type)
function double_scan_sampling(rng, iter, n::Int, replace, ordered)
N = get_population_size(iter)
single_scan_sampling(rng, iter, n, N, replace, ordered, iter_type)
single_scan_sampling(rng, iter, n, N, replace, ordered)
end

function single_scan_sampling(rng, iter, n::Int, replace, ordered, iter_type)
return single_scan_sampling(rng, iter, n, length(iter), replace, ordered, iter_type)
function single_scan_sampling(rng, iter, n::Int, replace, ordered)
return single_scan_sampling(rng, iter, n, length(iter), replace, ordered)
end

function single_scan_sampling(rng, iter, n::Int, N::Int, replace, ordered, iter_type)
function single_scan_sampling(rng, iter, n::Int, N::Int, replace, ordered)
if N <= n
reservoir = collect(iter)
if ordered
Expand All @@ -107,6 +109,7 @@ function single_scan_sampling(rng, iter, n::Int, N::Int, replace, ordered, iter_
return shuffle!(reservoir)
end
end
iter_type = Base.@default_eltype(iter)
it = iterate(iter)
el, state = it
reservoir = Vector{iter_type}(undef, n)
Expand Down

0 comments on commit e847225

Please sign in to comment.