diff --git a/test/episodes.jl b/test/episodes.jl index e95c420..f93c433 100644 --- a/test/episodes.jl +++ b/test/episodes.jl @@ -285,6 +285,7 @@ using Test show(eb); =# end + @testset "ElasticArraySARTSATraces with PartialNamedTuple" begin eb = EpisodesBuffer( ElasticArraySARTSATraces() @@ -312,8 +313,8 @@ using Test @test eb.step_numbers[end] == 1 @test eb.sampleable_inds == [1,1,1,1,1,0,0] @test eb[:action][6] == 6 - @test eb[:next_action][6] == 6 - @test eb[6][:reward] == 0 #6 is not a valid index, the reward there is dummy, filled as zero + @test eb[:next_action][5] == 6 + @test eb[:reward][6] == 0 #6 is not a valid index, the reward there is dummy, filled as zero ep2_len = 0 for (j,i) = enumerate(8:11) ep2_len += 1 @@ -325,7 +326,6 @@ using Test end @test eb.sampleable_inds == [1,1,1,1,1,0,1,1,1,1,0] @test length(eb.traces) == 9 #an action is missing at this stage - #three last steps replace oldest steps in the buffer. for (i, s) = enumerate(12:13) ep2_len += 1 push!(eb, (state = s, action =s-1, reward = s-1, terminal = false)) @@ -335,18 +335,18 @@ using Test @test eb.episodes_lengths[end-ep2_len:end] == fill(ep2_len, ep2_len + 1) end push!(eb, PartialNamedTuple((action = 13,))) - @test length(eb.traces) == 10 + @test length(eb.traces) == 12 #episode 1 - for (i,s) in enumerate(3:13) - if i in (4, 11) + for i in 1:13 + if i in (6, 13) @test eb.sampleable_inds[i] == 0 continue else @test eb.sampleable_inds[i] == 1 end b = eb[i] - @test b[:state] == b[:action] == b[:reward] == s - @test b[:next_state] == b[:next_action] == s + 1 + @test b[:state] == b[:action] == b[:reward] == i + @test b[:next_state] == b[:next_action] == i + 1 end #episode 2 #start a third episode @@ -360,9 +360,9 @@ using Test push!(eb, (state = s, action =s-1, reward = s-1, terminal = false)) end push!(eb, PartialNamedTuple((action = 26,))) - @test eb.sampleable_inds == [fill(true, 10); [false]] - @test eb.episodes_lengths == fill(length(15:26), 11) - @test eb.step_numbers == [3:13;] + @test eb.sampleable_inds[end-10:end] == [fill(true, 10); [false]] + @test eb.episodes_lengths[end-10:end] == fill(length(15:26), 11) + @test eb.step_numbers[end-10:end] == [3:13;] #= Deactivated until https://github.com/JuliaArrays/ElasticArrays.jl/pull/56/files merged and pop!/popfirst! added to ElasticArrays step = popfirst!(eb) @test length(eb) == length(eb.sampleable_inds) - 1 == length(eb.step_numbers) - 1 == length(eb.episodes_lengths) - 1 == 9 diff --git a/test/traces.jl b/test/traces.jl index 762c6f2..f30b311 100644 --- a/test/traces.jl +++ b/test/traces.jl @@ -139,7 +139,6 @@ end @testset "build_trace_index ElasticArraySARTSATraces" begin t1 = ElasticArraySARTSATraces(; - capacity=3, state=Float32 => (2, 3), action=Float32 => (2,), reward=Float32 => (), @@ -171,6 +170,29 @@ end @test size(Base.getindex(t1, 1).state) == (2,3) + t2 = Traces(; a=[2, 3], b=[false, true]) + push!(t2, Val(:a), 5) + @test t2[:a][3] == 5 + + @test size(Base.getindex(t2, :a)) == (3,) + @test Base.getindex(t2, 1) == (; a = 2, b= false) +end + + +@testset "push!(ts::Traces{names,Trs,N,E}, ::Val{k}, v)" begin + t1 = ElasticArraySARTSATraces( + state=Float32 => (2, 3), + action=Float32 => (2,), + reward=Float32 => (), + terminal=Bool => () + ) + push!(t1, Val(:reward), 5) + @test t1[:reward][1] == 5 + + @test size(Base.getindex(t1, :reward)) == (1,) + @test size(Base.getindex(t1, :state)) == (0,) + + t2 = Traces(; a=[2, 3], b=[false, true]) push!(t2, Val(:a), 5) @test t2[:a][3] == 5