Skip to content

Commit

Permalink
Fix weighted sampling single (#57)
Browse files Browse the repository at this point in the history
  • Loading branch information
Tortar committed Apr 16, 2024
1 parent ab98e76 commit 8a12a01
Showing 1 changed file with 9 additions and 9 deletions.
18 changes: 9 additions & 9 deletions src/WeightedSamplingSingle.jl
Original file line number Diff line number Diff line change
@@ -1,26 +1,26 @@

mutable struct WeightedResSampleSingle{T} <: AbstractReservoirSample
mutable struct WeightedResSampleSingle{T,R} <: AbstractReservoirSample
state::Float64
skip_w::Float64
rng::R
value::T
WeightedResSampleSingle{T}(state, skip_w) where T = new{T}(state, skip_w)
WeightedResSampleSingle{T,R}(state, skip_w, rng) where {T,R} = new{T,R}(state, skip_w, rng)
end

function WeightedReservoirSample(T)
return WeightedResSampleSingle{T}(0.0, 0.0)
function ReservoirSample(rng::R, T, method::AlgAExpJ) where {R<:AbstractRNG}
return WeightedResSampleSingle{T,R}(0.0, 0.0, rng)
end

function value(s::WeightedResSampleSingle)
s.state === 0.0 && return nothing
return s.value
end

update!(s::WeightedResSampleSingle, el, weight) = update!(Random.default_rng(), s, el, weight)
function update!(rng, s::WeightedResSampleSingle, el, weight)
function update!(s::WeightedResSampleSingle, el, weight)
s.state += weight
if s.skip_w < s.state
s.value = el
s.skip_w = skip(rng, s.state, 1)
s.skip_w = skip(s.rng, s.state, 1)
end
return s
end
Expand All @@ -30,9 +30,9 @@ function itsample(iter, wv::Function)
end

function itsample(rng::AbstractRNG, iter, wv::Function)
s = WeightedReservoirSample(Base.@default_eltype(iter))
s = ReservoirSample(rng, Base.@default_eltype(iter), algAExpJ)
for x in iter
@inline update!(rng, s, x, wv(x))
@inline update!(s, x, wv(x))
end
return value(s)
end

0 comments on commit 8a12a01

Please sign in to comment.