Skip to content

Commit

Permalink
eltype to itsample
Browse files Browse the repository at this point in the history
  • Loading branch information
Tortar committed Dec 30, 2023
1 parent c0719e9 commit 14ed2be
Showing 1 changed file with 12 additions and 15 deletions.
27 changes: 12 additions & 15 deletions src/UnweightedSamplingMulti.jl
Original file line number Diff line number Diff line change
@@ -1,30 +1,29 @@

function itsample(iter, n::Int;
function itsample(iter, n::Int, iter_type = Base.@default_eltype(iter);
replace = false, ordered = false, is_stateful = false)
return itsample(Random.GLOBAL_RNG, iter, n;
replace=replace, ordered=ordered, is_stateful=is_stateful)
end

function itsample(rng::AbstractRNG, iter, n::Int;
function itsample(rng::AbstractRNG, iter, n::Int, iter_type = Base.@default_eltype(iter);
replace = false, ordered = false, is_stateful = false)
IterHasKnownSize = Base.IteratorSize(iter)
if IterHasKnownSize isa NonIndexable
if is_stateful
if replace
error("Not implemented yet")
else
unweighted_resorvoir_sampling(rng, iter, n, Val(ordered))
unweighted_resorvoir_sampling(rng, iter, n, Val(ordered), iter_type)
end
else
double_scan_sampling(rng, iter, n, replace, ordered)
double_scan_sampling(rng, iter, n, replace, ordered, iter_type)
end
else
single_scan_sampling(rng, iter, n, replace, ordered)
single_scan_sampling(rng, iter, n, replace, ordered, iter_type)
end
end

function unweighted_resorvoir_sampling(rng, iter, n::Int, ::Val{false},
iter_type = Base.@default_eltype(iter))
function unweighted_resorvoir_sampling(rng, iter, n::Int, ::Val{false}, iter_type)
it = iterate(iter)
isnothing(it) && return iter_type[]
el, state = it
Expand Down Expand Up @@ -54,8 +53,7 @@ function unweighted_resorvoir_sampling(rng, iter, n::Int, ::Val{false},
end
end

function unweighted_resorvoir_sampling(rng, iter, n::Int, ::Val{true},
iter_type = Base.@default_eltype(iter))
function unweighted_resorvoir_sampling(rng, iter, n::Int, ::Val{true}, iter_type)
it = iterate(iter)
isnothing(it) && return iter_type[]
el, state = it
Expand Down Expand Up @@ -91,17 +89,16 @@ function unweighted_resorvoir_sampling(rng, iter, n::Int, ::Val{true},
end
end

function double_scan_sampling(rng, iter, n::Int, replace, ordered)
function double_scan_sampling(rng, iter, n::Int, replace, ordered, iter_type)
N = get_population_size(iter)
single_scan_sampling(rng, iter, n, N, replace, ordered)
single_scan_sampling(rng, iter, n, N, replace, ordered, iter_type)
end

function single_scan_sampling(rng, iter, n::Int, replace, ordered)
return single_scan_sampling(rng, iter, n, length(iter), replace, ordered)
function single_scan_sampling(rng, iter, n::Int, replace, ordered, iter_type)
return single_scan_sampling(rng, iter, n, length(iter), replace, ordered, iter_type)
end

function single_scan_sampling(rng, iter, n::Int, N::Int, replace, ordered,
iter_type = Base.@default_eltype(iter))
function single_scan_sampling(rng, iter, n::Int, N::Int, replace, ordered, iter_type)
if N <= n
reservoir = collect(iter)
if ordered
Expand Down

0 comments on commit 14ed2be

Please sign in to comment.