Skip to content

Commit

Permalink
Add update! for weighted sampling
Browse files Browse the repository at this point in the history
  • Loading branch information
Tortar committed Apr 16, 2024
1 parent 8a12a01 commit e0af817
Show file tree
Hide file tree
Showing 7 changed files with 188 additions and 197 deletions.
28 changes: 14 additions & 14 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -29,28 +29,28 @@ julia> iter = Iterators.filter(x -> x != 10, 1:10^7);
julia> wv(el) = 1.0

julia> @btime itsample($rng, $iter, 10^4, algRSWRSKIP);
9.675 ms (4 allocations: 156.34 KiB)

julia> @btime itsample($rng, $iter, 10^4, algL);
7.889 ms (2 allocations: 78.17 KiB)

julia> @btime itsample($rng, $iter, $wv, 10^4; replace=true);
12.493 ms (15 allocations: 547.23 KiB)

julia> @btime itsample($rng, $iter, $wv, 10^4; replace=false);
20.281 ms (5 allocations: 234.61 KiB)
14.579 ms (5 allocations: 156.39 KiB)

julia> @btime sample($rng, collect($iter), 10^4; replace=true);
137.932 ms (20 allocations: 146.91 MiB)
136.973 ms (20 allocations: 146.91 MiB)

julia> @btime itsample($rng, $iter, 10^4, algL);
10.630 ms (3 allocations: 78.22 KiB)

julia> @btime sample($rng, collect($iter), 10^4; replace=false);
139.212 ms (27 allocations: 147.05 MiB)
138.207 ms (27 allocations: 147.05 MiB)

julia> @btime itsample($rng, $iter, $wv, 10^4, algWRSWRSKIP);
32.756 ms (5 allocations: 156.41 KiB)

julia> @btime sample($rng, collect($iter), Weights($wv.($iter)), 10^4; replace=true);
315.508 ms (49 allocations: 675.21 MiB)
548.043 ms (45 allocations: 702.33 MiB)

julia> @btime itsample($rng, $iter, $wv, 10^4, algAExpJ);
40.849 ms (11 allocations: 234.78 KiB)

julia> @btime sample($rng, collect($iter), Weights($wv.($iter)), 10^4; replace=false);
317.230 ms (43 allocations: 370.19 MiB)
316.312 ms (43 allocations: 370.19 MiB)
```

More information can be found in the [documentation](https://juliadynamics.github.io/StreamSampling.jl/stable/).
8 changes: 8 additions & 0 deletions src/StreamSampling.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,20 @@ const worsample = WORSample()
const ordworsample = OrdWORSample()

# Root of the reservoir-sample type hierarchy: every abstract sample type
# declared below descends from it, directly or indirectly.
abstract type AbstractReservoirSample end

# unweighted cases
# Two parallel branches hang off AbstractReservoirSampleMulti: Wor and Wr
# (presumably "without replacement" / "with replacement" -- confirm against the
# concrete sample types). Each branch has an Ord subtype nested beneath it, so a
# method defined on the branch type also applies to its Ord variant.
abstract type AbstractReservoirSampleMulti <: AbstractReservoirSample end
abstract type AbstractWorReservoirSampleMulti <: AbstractReservoirSampleMulti end
abstract type AbstractOrdWorReservoirSampleMulti <: AbstractWorReservoirSampleMulti end
abstract type AbstractWrReservoirSampleMulti <: AbstractReservoirSampleMulti end
abstract type AbstractOrdWrReservoirSampleMulti <: AbstractWrReservoirSampleMulti end

# weighted cases
# NOTE(review): unlike the unweighted branch above, the weighted Wor/Wr types
# subtype AbstractReservoirSample directly and NOT
# AbstractWeightedReservoirSampleMulti. A method defined on
# AbstractWeightedReservoirSampleMulti therefore does not dispatch on them.
# Confirm this flat layout is intentional.
abstract type AbstractWeightedReservoirSampleMulti <: AbstractReservoirSample end
abstract type AbstractWeightedWorReservoirSampleMulti <: AbstractReservoirSample end
abstract type AbstractWeightedWrReservoirSampleMulti <: AbstractReservoirSample end


abstract type ReservoirAlgorithm end

struct AlgL <: ReservoirAlgorithm end
Expand Down
2 changes: 1 addition & 1 deletion src/UnweightedSamplingMulti.jl
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ function update!(s::AbstractWrReservoirSampleMulti, el)
elseif s.skip_k < 0
p = 1/s.seen_k
z = (1-p)^(n-3)
q = rand(s.rng, Uniform(z*(1-p)*(1-p)*(1-p),1))
q = rand(s.rng, Uniform(z*(1-p)*(1-p)*(1-p),1.0))
k = choose(n, p, q, z)
@inbounds begin
if k == 1
Expand Down
Loading

0 comments on commit e0af817

Please sign in to comment.