Skip to content

Commit

Permalink
tests pass
Browse files Browse the repository at this point in the history
  • Loading branch information
Jeremiah Lewis committed Mar 22, 2024
1 parent bd17951 commit 1cc7b76
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 12 deletions.
22 changes: 11 additions & 11 deletions test/episodes.jl
Original file line number Diff line number Diff line change
Expand Up @@ -285,6 +285,7 @@ using Test
show(eb);
=#
end

@testset "ElasticArraySARTSATraces with PartialNamedTuple" begin
eb = EpisodesBuffer(
ElasticArraySARTSATraces()
Expand Down Expand Up @@ -312,8 +313,8 @@ using Test
@test eb.step_numbers[end] == 1
@test eb.sampleable_inds == [1,1,1,1,1,0,0]
@test eb[:action][6] == 6
@test eb[:next_action][6] == 6
@test eb[6][:reward] == 0 #6 is not a valid index, the reward there is dummy, filled as zero
@test eb[:next_action][5] == 6
@test eb[:reward][6] == 0 #6 is not a valid index, the reward there is dummy, filled as zero
ep2_len = 0
for (j,i) = enumerate(8:11)
ep2_len += 1
Expand All @@ -325,7 +326,6 @@ using Test
end
@test eb.sampleable_inds == [1,1,1,1,1,0,1,1,1,1,0]
@test length(eb.traces) == 9 #an action is missing at this stage
#three last steps replace oldest steps in the buffer.
for (i, s) = enumerate(12:13)
ep2_len += 1
push!(eb, (state = s, action =s-1, reward = s-1, terminal = false))
Expand All @@ -335,18 +335,18 @@ using Test
@test eb.episodes_lengths[end-ep2_len:end] == fill(ep2_len, ep2_len + 1)
end
push!(eb, PartialNamedTuple((action = 13,)))
@test length(eb.traces) == 10
@test length(eb.traces) == 12
#episode 1
for (i,s) in enumerate(3:13)
if i in (4, 11)
for i in 1:13
if i in (6, 13)
@test eb.sampleable_inds[i] == 0
continue
else
@test eb.sampleable_inds[i] == 1
end
b = eb[i]
@test b[:state] == b[:action] == b[:reward] == s
@test b[:next_state] == b[:next_action] == s + 1
@test b[:state] == b[:action] == b[:reward] == i
@test b[:next_state] == b[:next_action] == i + 1
end
#episode 2
#start a third episode
Expand All @@ -360,9 +360,9 @@ using Test
push!(eb, (state = s, action =s-1, reward = s-1, terminal = false))
end
push!(eb, PartialNamedTuple((action = 26,)))
@test eb.sampleable_inds == [fill(true, 10); [false]]
@test eb.episodes_lengths == fill(length(15:26), 11)
@test eb.step_numbers == [3:13;]
@test eb.sampleable_inds[end-10:end] == [fill(true, 10); [false]]
@test eb.episodes_lengths[end-10:end] == fill(length(15:26), 11)
@test eb.step_numbers[end-10:end] == [3:13;]
#= Deactivated until https://github.com/JuliaArrays/ElasticArrays.jl/pull/56/files merged and pop!/popfirst! added to ElasticArrays
step = popfirst!(eb)
@test length(eb) == length(eb.sampleable_inds) - 1 == length(eb.step_numbers) - 1 == length(eb.episodes_lengths) - 1 == 9
Expand Down
24 changes: 23 additions & 1 deletion test/traces.jl
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,6 @@ end

@testset "build_trace_index ElasticArraySARTSATraces" begin
t1 = ElasticArraySARTSATraces(;
capacity=3,
state=Float32 => (2, 3),
action=Float32 => (2,),
reward=Float32 => (),
Expand Down Expand Up @@ -171,6 +170,29 @@ end
@test size(Base.getindex(t1, 1).state) == (2,3)


t2 = Traces(; a=[2, 3], b=[false, true])
push!(t2, Val(:a), 5)
@test t2[:a][3] == 5

@test size(Base.getindex(t2, :a)) == (3,)
@test Base.getindex(t2, 1) == (; a = 2, b= false)
end


@testset "push!(ts::Traces{names,Trs,N,E}, ::Val{k}, v)" begin
t1 = ElasticArraySARTSATraces(
state=Float32 => (2, 3),
action=Float32 => (2,),
reward=Float32 => (),
terminal=Bool => ()
)
push!(t1, Val(:reward), 5)
@test t1[:reward][1] == 5

@test size(Base.getindex(t1, :reward)) == (1,)
@test size(Base.getindex(t1, :state)) == (0,)


t2 = Traces(; a=[2, 3], b=[false, true])
push!(t2, Val(:a), 5)
@test t2[:a][3] == 5
Expand Down

0 comments on commit 1cc7b76

Please sign in to comment.