Skip to content

Commit

Permalink
always valid for stateful iterators
Browse files Browse the repository at this point in the history
  • Loading branch information
Tortar committed Dec 31, 2023
1 parent 3af60d6 commit 86a4ebc
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 46 deletions.
41 changes: 15 additions & 26 deletions src/IteratorSampling.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,37 +11,26 @@ include("UnweightedSamplingSingle.jl")
include("UnweightedSamplingMulti.jl")

"""
itsample([rng], iter, [condition::Function]; [alloc])
itsample([rng], iter)
Return a random element of the iterator, optionally specifying a `rng`
(which defaults to `Random.GLOBAL_RNG`) and a condition to restrict the
sampling on only those elements for which the function returns `true`.
If the iterator is empty or no random element satisfies the condition,
it returns `nothing`.
## Keywords
* `alloc = false`: this keyword chooses the algorithm to perform, if
`alloc = false` the algorithm doesn't allocate a new collection to
perform the sampling, which should be better when the number of elements is
large.
(which defaults to `Random.default_rng()`). If the iterator is empty, it
returns `nothing`.
-----
itsample([rng], iter, [condition::Function], n::Int; [alloc, iter_type])
Return a vector of `n` random elements of the iterator without replacement,
optionally specifying a `rng` (which defaults to `Random.GLOBAL_RNG`) and
a condition to restrict the sampling on only those elements for which the
function returns `true`. If the iterator has less than `n` elements or less
than `n` elements satisfy the condition, it returns a vector of these elements.
## Keywords
* `alloc = true`: when the function returns a vector, it happens to be much
better to use the allocating version for small iterators.
* `iter_type = Any`: the iterator type of the given iterator, if not given
it defaults to `Any`, which means that the returned vector will be also of
`Any` type. For performance reasons, if you can infer the type of the iterator,
it is better to pass it.
itsample([rng], iter, n::Int; replace = true, ordered = false)
Return a vector of `n` random elements of the iterator,
optionally specifying a `rng` (which defaults to `Random.default_rng()`).
`replace` dictates whether sampling is performed with replacement.
`ordered` dictates whether an ordered sample (also called a sequential
sample, i.e. a sample where items appear in the same order as in `iter`).
If the iterator has less than `n` elements, it returns a vector of
these elements.
"""
function itsample end

Expand Down
18 changes: 6 additions & 12 deletions src/UnweightedSamplingMulti.jl
Original file line number Diff line number Diff line change
@@ -1,22 +1,16 @@

function itsample(iter, n::Int;
replace = false, ordered = false, is_stateful = false)
return itsample(Random.GLOBAL_RNG, iter, n;
replace=replace, ordered=ordered, is_stateful=is_stateful)
function itsample(iter, n::Int; replace = false, ordered = false)
return itsample(Random.default_rng(), iter, n; replace=replace, ordered=ordered)
end

function itsample(rng::AbstractRNG, iter, n::Int;
replace = false, ordered = false, is_stateful = false)
replace = false, ordered = false)
IterHasKnownSize = Base.IteratorSize(iter)
if IterHasKnownSize isa NonIndexable
if is_stateful
if replace
error("Not implemented yet")
else
unweighted_resorvoir_sampling(rng, iter, n, Val(ordered))
end
if replace
error("Not implemented yet")
else
double_scan_sampling(rng, iter, n, replace, ordered)
unweighted_resorvoir_sampling(rng, iter, n, Val(ordered))
end
else
single_scan_sampling(rng, iter, n, replace, ordered)
Expand Down
12 changes: 4 additions & 8 deletions src/UnweightedSamplingSingle.jl
Original file line number Diff line number Diff line change
@@ -1,16 +1,12 @@

function itsample(iter; is_stateful = false)
return itsample(Random.GLOBAL_RNG, iter; is_stateful = false)
function itsample(iter)
return itsample(Random.default_rng(), iter)
end

function itsample(rng::AbstractRNG, iter; is_stateful = false)
function itsample(rng::AbstractRNG, iter)
IterHasKnownSize = Base.IteratorSize(iter)
if IterHasKnownSize isa NonIndexable
if is_stateful
unweighted_resorvoir_sampling(rng, iter)
else
double_scan_sampling(rng, iter)
end
return unweighted_resorvoir_sampling(rng, iter)
else
return single_scan_sampling(rng, iter)
end
Expand Down

0 comments on commit 86a4ebc

Please sign in to comment.