From f33800cc6305f3501ae4f065c425caa57ac28e37 Mon Sep 17 00:00:00 2001 From: Tim Wheeler Date: Sun, 1 Dec 2024 08:27:08 -0800 Subject: [PATCH] decode for linear disc should propagate sampling method --- src/linear_discretizer.jl | 4 ++-- test/test_linear_discretizer.jl | 17 +++++++++++++++++ 2 files changed, 19 insertions(+), 2 deletions(-) diff --git a/src/linear_discretizer.jl b/src/linear_discretizer.jl index 9648841..dc80b2a 100644 --- a/src/linear_discretizer.jl +++ b/src/linear_discretizer.jl @@ -121,10 +121,10 @@ decode(ld::LinearDiscretizer{N,D}, d::D) where {N<:Integer,D<:Integer} = decode( decode(ld::LinearDiscretizer{N,D}, d::I, method::AbstractSampleMethod=SAMPLE_UNIFORM) where {N<:Real,D<:Integer,I<:Integer} = decode(ld, convert(D,d), method) -function decode(ld::LinearDiscretizer{N,D}, data::AbstractArray{D}, ::AbstractSampleMethod=SAMPLE_UNIFORM) where {N<:Real,D<:Integer} +function decode(ld::LinearDiscretizer{N,D}, data::AbstractArray{D}, method::AbstractSampleMethod=SAMPLE_UNIFORM) where {N<:Real,D<:Integer} arr = Vector{N}(undef, length(data)) for (i,d) in enumerate(data) - arr[i] = decode(ld, d) + arr[i] = decode(ld, d, method) end reshape(arr, size(data)) end diff --git a/test/test_linear_discretizer.jl b/test/test_linear_discretizer.jl index 9cac70b..aeaf535 100644 --- a/test/test_linear_discretizer.jl +++ b/test/test_linear_discretizer.jl @@ -140,3 +140,20 @@ ld = LinearDiscretizer([0,10,20], Int, force_outliers_to_closest=false) @test supports_encoding(ld, 16) @test !supports_encoding(ld, -1) @test !supports_encoding(ld, 21) + +let + # Ensure that linear discretizer propoagates the sampling method + binedges = [0, 0.5, 1] + lineardisc = LinearDiscretizer(binedges) + + decode_1 = decode(lineardisc, 1, SAMPLE_BIN_CENTER) # 0.25 + decode_2 = decode(lineardisc, 2, SAMPLE_BIN_CENTER) # 0.75 + decode_12_uniform = decode(lineardisc, [1, 2], SAMPLE_UNIFORM) + + @test !(decode_1 ≈ decode_12_uniform[1]) + @test !(decode_2 ≈ decode_12_uniform[2]) + + decode_12_center = decode(lineardisc, [1, 2], SAMPLE_BIN_CENTER) + @test decode_1 ≈ decode_12_center[1] + @test decode_2 ≈ decode_12_center[2] +end \ No newline at end of file