Skip to content

Commit

Permalink
Compatibility with ResumableFunctions.jl (#25)
Browse files Browse the repository at this point in the history
  • Loading branch information
Tortar committed Jan 13, 2024
1 parent 18cfa11 commit 22ee933
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 4 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ version = "0.2.2"
[deps]
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
ResumableFunctions = "c5292f4c-5179-55e1-98c5-05642aab7184"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"

[compat]
Expand All @@ -21,4 +22,3 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
[targets]
test = ["HypothesisTests", "Test",
"BenchmarkTools", "StableRNGs"]

1 change: 1 addition & 0 deletions src/IteratorSampling.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ module IteratorSampling

using Distributions
using Random
using ResumableFunctions
using StatsBase

struct WRSample end
Expand Down
14 changes: 11 additions & 3 deletions src/UnweightedSamplingMulti.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ end

function itsample(rng::AbstractRNG, iter, n::Int;
replace = false, ordered = false)
iter_type = Base.@default_eltype(iter)
iter_type = calculate_eltype(iter)
if Base.IteratorSize(iter) isa Base.SizeUnknown
reservoir_sample(rng, iter, n; replace, ordered)::Vector{iter_type}
else
Expand All @@ -30,7 +30,7 @@ function reservoir_sample(rng, iter, n; replace = false, ordered = false)
end

function reservoir_sample(rng, iter, n::Int, is::Union{WORSample, OrdWORSample})
iter_type = Base.@default_eltype(iter)
iter_type = calculate_eltype(iter)
it = iterate(iter)
isnothing(it) && return iter_type[]
el, state = it
Expand Down Expand Up @@ -58,7 +58,7 @@ function reservoir_sample(rng, iter, n::Int, is::Union{WORSample, OrdWORSample})
end

function reservoir_sample(rng, iter, n::Int, is::Union{WRSample, OrdWRSample})
iter_type = Base.@default_eltype(iter)
iter_type = calculate_eltype(iter)
it = iterate(iter)
isnothing(it) && return iter_type[]
reservoir = Vector{iter_type}(undef, n)
Expand Down Expand Up @@ -221,3 +221,11 @@ end
function transform(rng, reservoir, order::Nothing, ::Union{OrdWORSample, OrdWRSample})
return reservoir
end

function calculate_eltype(iter)
return Base.@default_eltype(iter)
end
function calculate_eltype(iter::ResumableFunctions.FiniteStateMachineIterator)
return eltype(iter)
end

0 comments on commit 22ee933

Please sign in to comment.