From 22ee933e157c1f9f0144428f33cbf0b5428536e4 Mon Sep 17 00:00:00 2001 From: Tortar <68152031+Tortar@users.noreply.github.com> Date: Sat, 13 Jan 2024 19:32:45 +0100 Subject: [PATCH] Compatibility with ResumableFunctions.jl (#25) --- Project.toml | 2 +- src/IteratorSampling.jl | 1 + src/UnweightedSamplingMulti.jl | 14 +++++++++++--- 3 files changed, 13 insertions(+), 4 deletions(-) diff --git a/Project.toml b/Project.toml index 184489b..2d86fae 100644 --- a/Project.toml +++ b/Project.toml @@ -5,6 +5,7 @@ version = "0.2.2" [deps] Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" +ResumableFunctions = "c5292f4c-5179-55e1-98c5-05642aab7184" StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" [compat] @@ -21,4 +22,3 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [targets] test = ["HypothesisTests", "Test", "BenchmarkTools", "StableRNGs"] - diff --git a/src/IteratorSampling.jl b/src/IteratorSampling.jl index 2b7928d..e190128 100644 --- a/src/IteratorSampling.jl +++ b/src/IteratorSampling.jl @@ -2,6 +2,7 @@ module IteratorSampling using Distributions using Random +using ResumableFunctions using StatsBase struct WRSample end diff --git a/src/UnweightedSamplingMulti.jl b/src/UnweightedSamplingMulti.jl index 0ce87d4..a6d8f8c 100644 --- a/src/UnweightedSamplingMulti.jl +++ b/src/UnweightedSamplingMulti.jl @@ -5,7 +5,7 @@ end function itsample(rng::AbstractRNG, iter, n::Int; replace = false, ordered = false) - iter_type = Base.@default_eltype(iter) + iter_type = calculate_eltype(iter) if Base.IteratorSize(iter) isa Base.SizeUnknown reservoir_sample(rng, iter, n; replace, ordered)::Vector{iter_type} else @@ -30,7 +30,7 @@ function reservoir_sample(rng, iter, n; replace = false, ordered = false) end function reservoir_sample(rng, iter, n::Int, is::Union{WORSample, OrdWORSample}) - iter_type = Base.@default_eltype(iter) + iter_type = calculate_eltype(iter) it = iterate(iter) isnothing(it) && return iter_type[] el, state = it @@ -58,7 +58,7 @@ function reservoir_sample(rng, iter, n::Int, is::Union{WORSample, OrdWORSample}) end function reservoir_sample(rng, iter, n::Int, is::Union{WRSample, OrdWRSample}) - iter_type = Base.@default_eltype(iter) + iter_type = calculate_eltype(iter) it = iterate(iter) isnothing(it) && return iter_type[] reservoir = Vector{iter_type}(undef, n) @@ -221,3 +221,11 @@ end function transform(rng, reservoir, order::Nothing, ::Union{OrdWORSample, OrdWRSample}) return reservoir end + +function calculate_eltype(iter) + return Base.@default_eltype(iter) +end +function calculate_eltype(iter::ResumableFunctions.FiniteStateMachineIterator) + return eltype(iter) +end +