diff --git a/test/common.jl b/test/common.jl index 63e2309..e445e47 100644 --- a/test/common.jl +++ b/test/common.jl @@ -129,9 +129,9 @@ end @test t isa CircularArraySLARTTraces end -@testset "CircularPrioritizedTraces" begin +@testset "CircularPrioritizedTraces-SARTS" begin t = CircularPrioritizedTraces( - CircularArraySARTSATraces(; + CircularArraySARTSTraces(; capacity=3 ), default_priority=1.0f0 @@ -161,8 +161,68 @@ end @test b.key == [4, 4, 4, 4, 4] # the priority of the rest transitions are set to 0 #EpisodesBuffer + t = CircularPrioritizedTraces( + CircularArraySARTSTraces(; + capacity=10 + ), + default_priority=1.0f0 + ) + + eb = EpisodesBuffer(t) + push!(eb, (state = 1, action = 1)) + for i = 1:5 + push!(eb, (state = i+1, action =i+1, reward = i, terminal = false)) + end + push!(eb, (state = 7, action = 7)) + for (j,i) = enumerate(8:11) + push!(eb, (state = i, action =i, reward = i-1, terminal = false)) + end + s = BatchSampler(1000) + b = sample(s, eb) + cm = counter(b[:state]) + @test !haskey(cm, 6) + @test !haskey(cm, 11) + @test all(in(keys(cm)), [1:5;7:10]) + + + eb[:priority, [1, 2]] = [0, 0] + @test eb[:priority] == [zeros(2);ones(8)] +end + +@testset "CircularPrioritizedTraces-SARTSA" begin t = CircularPrioritizedTraces( CircularArraySARTSATraces(; + capacity=3 + ), + default_priority=1.0f0 + ) + + push!(t, (state=0, action=0)) + + for i in 1:5 + push!(t, (reward=1.0f0, terminal=false, state=i, action=i)) + end + + @test length(t) == 3 + + s = BatchSampler(5) + + b = sample(s, t) + + t[:priority, [1, 2]] = [0, 0] + + # shouldn't be changed since [1,2] are old keys + @test t[:priority] == [1.0f0, 1.0f0, 1.0f0] + + t[:priority, [3, 4, 5]] = [0, 1, 0] + + b = sample(s, t) + + @test b.key == [4, 4, 4, 4, 4] # the priority of the rest transitions are set to 0 + + #EpisodesBuffer + t = CircularPrioritizedTraces( + CircularArraySARTSTraces(; capacity=10 ), default_priority=1.0f0