Skip to content

Commit

Permalink
New test for CircularPrioritizedTraces with SARTSA
Browse files Browse the repository at this point in the history
  • Loading branch information
Dharanish committed May 9, 2024
1 parent 73f2efb commit fbf054a
Showing 1 changed file with 62 additions and 2 deletions.
64 changes: 62 additions & 2 deletions test/common.jl
Original file line number Diff line number Diff line change
Expand Up @@ -129,9 +129,9 @@ end
@test t isa CircularArraySLARTTraces
end

@testset "CircularPrioritizedTraces" begin
@testset "CircularPrioritizedTraces-SARTS" begin
t = CircularPrioritizedTraces(
CircularArraySARTSATraces(;
CircularArraySARTSTraces(;
capacity=3
),
default_priority=1.0f0
Expand Down Expand Up @@ -161,8 +161,68 @@ end
@test b.key == [4, 4, 4, 4, 4] # the priority of the rest transitions are set to 0

#EpisodesBuffer
t = CircularPrioritizedTraces(
CircularArraySARTSTraces(;
capacity=10
),
default_priority=1.0f0
)

eb = EpisodesBuffer(t)
push!(eb, (state = 1, action = 1))
for i = 1:5
push!(eb, (state = i+1, action =i+1, reward = i, terminal = false))
end
push!(eb, (state = 7, action = 7))
for (j,i) = enumerate(8:11)
push!(eb, (state = i, action =i, reward = i-1, terminal = false))
end
s = BatchSampler(1000)
b = sample(s, eb)
cm = counter(b[:state])
@test !haskey(cm, 6)
@test !haskey(cm, 11)
@test all(in(keys(cm)), [1:5;7:10])


eb[:priority, [1, 2]] = [0, 0]
@test eb[:priority] == [zeros(2);ones(8)]
end

@testset "CircularPrioritizedTraces-SARTSA" begin
t = CircularPrioritizedTraces(
CircularArraySARTSATraces(;
capacity=3
),
default_priority=1.0f0
)

push!(t, (state=0, action=0))

for i in 1:5
push!(t, (reward=1.0f0, terminal=false, state=i, action=i))
end

@test length(t) == 3

s = BatchSampler(5)

b = sample(s, t)

t[:priority, [1, 2]] = [0, 0]

# shouldn't be changed since [1,2] are old keys
@test t[:priority] == [1.0f0, 1.0f0, 1.0f0]

t[:priority, [3, 4, 5]] = [0, 1, 0]

b = sample(s, t)

@test b.key == [4, 4, 4, 4, 4] # the priority of the rest transitions are set to 0

#EpisodesBuffer
t = CircularPrioritizedTraces(
CircularArraySARTSTraces(;
capacity=10
),
default_priority=1.0f0
Expand Down

0 comments on commit fbf054a

Please sign in to comment.