Skip to content

Commit

Permalink
Update UnweightedSamplingSingle.jl
Browse files Browse the repository at this point in the history
  • Loading branch information
Tortar authored Apr 26, 2024
1 parent fe846f2 commit 5957de8
Showing 1 changed file with 15 additions and 12 deletions.
27 changes: 15 additions & 12 deletions src/UnweightedSamplingSingle.jl
Original file line number Diff line number Diff line change
Expand Up @@ -36,20 +36,20 @@ function ReservoirSample(rng::AbstractRNG, T, method::AlgR)
end

@inline function update!(s::SampleSingleAlgR, el)
@imm_reset s.seen_k += 1
s.seen_k += 1
if rand(s.rng) <= 1/s.seen_k
@imm_reset s.value = el
s.value = el
end
return s
end
@inline function update!(s::SampleSingleAlgL, el)
s.seen_k += 1
if s.skip_k > 0
@imm_reset s.skip_k -= 1
s.skip_k -= 1
else
@imm_reset s.value = el
@imm_reset s.state *= rand(s.rng)
@imm_reset s.skip_k = -ceil(Int, randexp(s.rng)/log(1-s.state))
s.value = el
s.state *= rand(s.rng)
s.skip_k = -ceil(Int, randexp(s.rng)/log(1-s.state))
end
return s
end
Expand All @@ -73,22 +73,25 @@ function Base.merge!(s1::SampleSingleAlgR, s2::AbstractReservoirSampleSingle)
return s1
end

function itsample(iter, method::ReservoirAlgorithm = algL;
function itsample(iter, method::ReservoirAlgorithm = algL,
iter_type = infer_eltype(iter))
return itsample(Random.default_rng(), iter, method; iter_type)
return itsample(Random.default_rng(), iter, method, iter_type)
end
function itsample(rng::AbstractRNG, iter, method::ReservoirAlgorithm = algL;
function itsample(rng::AbstractRNG, iter, method::ReservoirAlgorithm = algL,
iter_type = infer_eltype(iter))
if Base.IteratorSize(iter) isa Base.SizeUnknown
return reservoir_sample(rng, iter, method; iter_type)
return reservoir_sample(rng, iter, iter_type, method)
else
return sortedindices_sample(rng, iter)
end
end

function reservoir_sample(rng, iter, method::ReservoirAlgorithm = algL;
iter_type = infer_eltype(iter))
function reservoir_sample(rng, iter, iter_type, method::ReservoirAlgorithm = algL)
s = ReservoirSample(rng, iter_type, method)
return update_all!(s, iter)
end

function update_all!(s, iter)
for x in iter
s = update!(s, x)
end
Expand Down

0 comments on commit 5957de8

Please sign in to comment.