Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Missing Elastic Methods and tests #66

Merged
merged 11 commits into from
Mar 22, 2024
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "ReinforcementLearningTrajectories"
uuid = "6486599b-a3cd-4e92-a99a-2cea90cc8c3c"
version = "0.3.7"
version = "0.4"

[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
Expand Down
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
export ElasticArraySARTTraces
export ElasticArraySARTSATraces

using ElasticArrays: ElasticArray, resize_lastdim!

const ElasticArraySARTTraces = Traces{
const ElasticArraySARTSATraces = Traces{
SS′AA′RT,
<:Tuple{
<:MultiplexTraces{SS′,<:Trace{<:ElasticArray}},
Expand All @@ -12,7 +10,7 @@ const ElasticArraySARTTraces = Traces{
}
}

function ElasticArraySARTTraces(;
function ElasticArraySARTSATraces(;
state=Int => (),
action=Int => (),
reward=Float32 => (),
Expand All @@ -31,10 +29,3 @@ function ElasticArraySARTTraces(;
)
end

#####
# extensions for ElasticArrays
#####

Base.push!(a::ElasticArray, x) = append!(a, x)
Base.push!(a::ElasticArray{T,1}, x) where {T} = append!(a, [x])
Base.empty!(a::ElasticArray) = resize_lastdim!(a, 0)
30 changes: 30 additions & 0 deletions src/common/ElasticArraySARTSTraces.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
export ElasticArraySARTSTraces

const ElasticArraySARTSTraces = Traces{
SS′ART,
<:Tuple{
<:MultiplexTraces{SS′,<:Trace{<:ElasticArray}},
<:Trace{<:ElasticArray},
<:Trace{<:ElasticArray},
<:Trace{<:ElasticArray},
}
}

function ElasticArraySARTSTraces(;
state=Int => (),
action=Int => (),
reward=Float32 => (),
terminal=Bool => ())

state_eltype, state_size = state
action_eltype, action_size = action
reward_eltype, reward_size = reward
terminal_eltype, terminal_size = terminal

MultiplexTraces{SS′}(ElasticArray{state_eltype}(undef, state_size..., 0)) +
Traces(
action = ElasticArray{action_eltype}(undef, action_size..., 0),
reward=ElasticArray{reward_eltype}(undef, reward_size..., 0),
terminal=ElasticArray{terminal_eltype}(undef, terminal_size..., 0),
)
end
35 changes: 35 additions & 0 deletions src/common/ElasticArraySLARTTraces.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
export ElasticArraySLARTTraces

const ElasticArraySLARTTraces = Traces{
SS′LL′AA′RT,
<:Tuple{
<:MultiplexTraces{SS′,<:Trace{<:ElasticArray}},
<:MultiplexTraces{LL′,<:Trace{<:ElasticArray}},
<:MultiplexTraces{AA′,<:Trace{<:ElasticArray}},
<:Trace{<:ElasticArray},
<:Trace{<:ElasticArray},
}
}

function ElasticArraySLARTTraces(;
capacity::Int,
state=Int => (),
legal_actions_mask=Bool => (),
action=Int => (),
reward=Float32 => (),
terminal=Bool => ()
)
state_eltype, state_size = state
action_eltype, action_size = action
legal_actions_mask_eltype, legal_actions_mask_size = legal_actions_mask
reward_eltype, reward_size = reward
terminal_eltype, terminal_size = terminal

MultiplexTraces{SS′}(ElasticArray{state_eltype}(undef, state_size..., 0)) +
MultiplexTraces{LL′}(ElasticArray{legal_actions_mask_eltype}(undef, legal_actions_mask_size..., 0)) +
MultiplexTraces{AA′}(ElasticArray{action_eltype}(undef, action_size..., 0)) +
Traces(
reward=ElasticArray{reward_eltype}(undef, reward_size..., 0),
terminal=ElasticArray{terminal_eltype}(undef, terminal_size..., 0),
)
end
5 changes: 4 additions & 1 deletion src/common/common.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,4 +14,7 @@ include("CircularArraySARTSTraces.jl")
include("CircularArraySARTSATraces.jl")
include("CircularArraySLARTTraces.jl")
include("CircularPrioritizedTraces.jl")
include("ElasticArraySARTTraces.jl")
include("common_elastic_array.jl")
include("ElasticArraySARTSTraces.jl")
include("ElasticArraySARTSATraces.jl")
include("ElasticArraySLARTTraces.jl")
9 changes: 9 additions & 0 deletions src/common/common_elastic_array.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
using ElasticArrays: ElasticArray, resize_lastdim!

#####
# extensions for ElasticArrays
#####

Base.push!(a::ElasticArray, x) = append!(a, x)
Base.push!(a::ElasticArray{T,1}, x) where {T} = append!(a, [x])
Base.empty!(a::ElasticArray) = resize_lastdim!(a, 0)
3 changes: 3 additions & 0 deletions src/episodes.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
export EpisodesBuffer, PartialNamedTuple
import DataStructures.CircularBuffer
using ElasticArrays: ElasticArray, ElasticVector

"""
EpisodesBuffer(traces::AbstractTraces)
Expand Down Expand Up @@ -90,6 +91,8 @@ function pad!(trace::Trace)
return nothing
end

pad!(vect::ElasticArray{T, Vector{T}}) where {T} = push!(vect, zero(T))
pad!(vect::ElasticVector{T, Vector{T}}) where {T} = push!(vect, zero(T))
pad!(buf::CircularArrayBuffer{T,N,A}) where {T,N,A} = push!(buf, zero(T))
pad!(vect::Vector{T}) where {T} = push!(vect, zero(T))

Expand Down
2 changes: 2 additions & 0 deletions src/traces.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ export Trace, Traces, MultiplexTraces
import MacroTools: @forward

import CircularArrayBuffers.CircularArrayBuffer
using ElasticArrays: ElasticArray
import Adapt

#####
Expand Down Expand Up @@ -55,6 +56,7 @@ Base.setindex!(s::Trace, v, I) = setindex!(s.parent, v, ntuple(i -> i == ndims(s
capacity(t::AbstractTrace) = ReinforcementLearningTrajectories.capacity(t.parent)
capacity(t::CircularArrayBuffer) = CircularArrayBuffers.capacity(t)
capacity(::AbstractVector) = Inf
capacity(::ElasticArray) = Inf

#####

Expand Down
8 changes: 4 additions & 4 deletions test/common.jl
Original file line number Diff line number Diff line change
Expand Up @@ -94,15 +94,15 @@ end
@test batch.terminal == Bool[0, 0, 0] |> gpu
end

@testset "ElasticArraySARTTraces" begin
t = ElasticArraySARTTraces(;
@testset "ElasticArraySARTSTraces" begin
t = ElasticArraySARTSTraces(;
state=Float32 => (2, 3),
action=Int => (),
reward=Float32 => (),
terminal=Bool => ()
)

@test t isa ElasticArraySARTTraces
@test t isa ElasticArraySARTSTraces

push!(t, (state=ones(Float32, 2, 3), action=1))
push!(t, (reward=1.0f0, terminal=false, state=ones(Float32, 2, 3) * 2, action=2))
Expand Down Expand Up @@ -185,4 +185,4 @@ end

eb[:priority, [1, 2]] = [0, 0]
@test eb[:priority] == [zeros(2);ones(8)]
end
end
179 changes: 179 additions & 0 deletions test/episodes.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
using ReinforcementLearningTrajectories
using CircularArrayBuffers
using Test

@testset "EpisodesBuffer" begin
@testset "with circular traces" begin
eb = EpisodesBuffer(
Expand Down Expand Up @@ -113,6 +114,8 @@ using Test
@test eb.episodes_lengths[end] == 0
@test eb.step_numbers[end] == 1
@test eb.sampleable_inds == [1,1,1,1,1,0,0]
@test eb[:action][6] == 6
@test eb[:next_action][6] == 6
@test eb[6][:reward] == 0 #6 is not a valid index, the reward there is dummy, filled as zero
ep2_len = 0
for (j,i) = enumerate(8:11)
Expand Down Expand Up @@ -197,4 +200,180 @@ using Test
@test eb.step_numbers == [1:16;1:16]
@test length(eb) == 31
end
@testset "with ElasticArraySARTSTraces traces" begin
eb = EpisodesBuffer(
ElasticArraySARTSTraces()
)
#push a first episode l=5
push!(eb, (state = 1,))
@test eb.sampleable_inds[end] == 0
@test eb.episodes_lengths[end] == 0
@test eb.step_numbers[end] == 1
for i = 1:5
push!(eb, (state = i+1, action =i, reward = i, terminal = false))
@test eb.sampleable_inds[end] == 0
@test eb.sampleable_inds[end-1] == 1
@test eb.step_numbers[end] == i + 1
@test eb.episodes_lengths[end-i:end] == fill(i, i+1)
end
@test eb.sampleable_inds == [1,1,1,1,1,0]
@test length(eb.traces) == 5
#start new episode of 6 periods.
push!(eb, (state = 7,))
@test eb.sampleable_inds[end] == 0
@test eb.sampleable_inds[end-1] == 0
@test eb.episodes_lengths[end] == 0
@test eb.step_numbers[end] == 1
@test eb.sampleable_inds == [1,1,1,1,1,0,0]
@test eb[6][:reward] == 0 #6 is not a valid index, the reward there is filled as zero
ep2_len = 0
for (j,i) = enumerate(8:11)
ep2_len += 1
push!(eb, (state = i, action =i-1, reward = i-1, terminal = false))
@test eb.sampleable_inds[end] == 0
@test eb.sampleable_inds[end-1] == 1
@test eb.step_numbers[end] == j + 1
@test eb.episodes_lengths[end-j:end] == fill(ep2_len, ep2_len + 1)
end
@test eb.sampleable_inds == [1,1,1,1,1,0,1,1,1,1,0]
@test length(eb.traces) == 10

for (i, s) = enumerate(12:13)
ep2_len += 1
push!(eb, (state = s, action =s-1, reward = s-1, terminal = false))
@test eb.sampleable_inds[end] == 0
@test eb.sampleable_inds[end-1] == 1
@test eb.step_numbers[end] == i + 1 + 4
@test eb.episodes_lengths[end-ep2_len:end] == fill(ep2_len, ep2_len + 1)
end
#episode 1
for i in 3:13
if i in (6, 13)
@test eb.sampleable_inds[i] == 0
continue
else
@test eb.sampleable_inds[i] == 1
end
b = eb[i]
@test b[:state] == b[:action] == b[:reward] == i
@test b[:next_state] == i + 1
end
#episode 2
#start a third episode
push!(eb, (state = 14, ))
@test eb.sampleable_inds[end] == 0
@test eb.sampleable_inds[end-1] == 0
@test eb.episodes_lengths[end] == 0
@test eb.step_numbers[end] == 1

for (i,s) in enumerate(15:26)
push!(eb, (state = s, action =s-1, reward = s-1, terminal = false))
end
@test eb.sampleable_inds[end-5:end] == [fill(true, 5); [false]]
@test eb.episodes_lengths[end-10:end] == fill(length(15:26), 11)
@test eb.step_numbers[end-10:end] == [3:13;]
#= Deactivated until https://github.com/JuliaArrays/ElasticArrays.jl/pull/56/files merged and pop!/popfirst! added to ElasticArrays
step = popfirst!(eb)
@test length(eb) == length(eb.sampleable_inds) - 1 == length(eb.step_numbers) - 1 == length(eb.episodes_lengths) - 1 == 9
@test first(eb.step_numbers) == 4
step = pop!(eb)
@test length(eb) == length(eb.sampleable_inds) - 1 == length(eb.step_numbers) - 1 == length(eb.episodes_lengths) - 1 == 8
@test last(eb.step_numbers) == 12
@test size(eb) == size(eb.traces) == (8,)
empty!(eb)
@test size(eb) == (0,) == size(eb.traces) == size(eb.sampleable_inds) == size(eb.episodes_lengths) == size(eb.step_numbers)
show(eb);
=#
end

@testset "ElasticArraySARTSATraces with PartialNamedTuple" begin
eb = EpisodesBuffer(
ElasticArraySARTSATraces()
)
#push a first episode l=5
push!(eb, (state = 1,))
@test eb.sampleable_inds[end] == 0
@test eb.episodes_lengths[end] == 0
@test eb.step_numbers[end] == 1
for i = 1:5
push!(eb, (state = i+1, action =i, reward = i, terminal = false))
@test eb.sampleable_inds[end] == 0
@test eb.sampleable_inds[end-1] == 1
@test eb.step_numbers[end] == i + 1
@test eb.episodes_lengths[end-i:end] == fill(i, i+1)
end
push!(eb, PartialNamedTuple((action = 6,)))
@test eb.sampleable_inds == [1,1,1,1,1,0]
@test length(eb.traces) == 5
#start new episode of 6 periods.
push!(eb, (state = 7,))
@test eb.sampleable_inds[end] == 0
@test eb.sampleable_inds[end-1] == 0
@test eb.episodes_lengths[end] == 0
@test eb.step_numbers[end] == 1
@test eb.sampleable_inds == [1,1,1,1,1,0,0]
@test eb[:action][6] == 6
@test eb[:next_action][5] == 6
@test eb[:reward][6] == 0 #6 is not a valid index, the reward there is dummy, filled as zero
ep2_len = 0
for (j,i) = enumerate(8:11)
ep2_len += 1
push!(eb, (state = i, action =i-1, reward = i-1, terminal = false))
@test eb.sampleable_inds[end] == 0
@test eb.sampleable_inds[end-1] == 1
@test eb.step_numbers[end] == j + 1
@test eb.episodes_lengths[end-j:end] == fill(ep2_len, ep2_len + 1)
end
@test eb.sampleable_inds == [1,1,1,1,1,0,1,1,1,1,0]
@test length(eb.traces) == 9 #an action is missing at this stage
for (i, s) = enumerate(12:13)
ep2_len += 1
push!(eb, (state = s, action =s-1, reward = s-1, terminal = false))
@test eb.sampleable_inds[end] == 0
@test eb.sampleable_inds[end-1] == 1
@test eb.step_numbers[end] == i + 1 + 4
@test eb.episodes_lengths[end-ep2_len:end] == fill(ep2_len, ep2_len + 1)
end
push!(eb, PartialNamedTuple((action = 13,)))
@test length(eb.traces) == 12
#episode 1
for i in 1:13
if i in (6, 13)
@test eb.sampleable_inds[i] == 0
continue
else
@test eb.sampleable_inds[i] == 1
end
b = eb[i]
@test b[:state] == b[:action] == b[:reward] == i
@test b[:next_state] == b[:next_action] == i + 1
end
#episode 2
#start a third episode
push!(eb, (state = 14,))
@test eb.sampleable_inds[end] == 0
@test eb.sampleable_inds[end-1] == 0
@test eb.episodes_lengths[end] == 0
@test eb.step_numbers[end] == 1
#push until it reaches it own start
for (i,s) in enumerate(15:26)
push!(eb, (state = s, action =s-1, reward = s-1, terminal = false))
end
push!(eb, PartialNamedTuple((action = 26,)))
@test eb.sampleable_inds[end-10:end] == [fill(true, 10); [false]]
@test eb.episodes_lengths[end-10:end] == fill(length(15:26), 11)
@test eb.step_numbers[end-10:end] == [3:13;]
#= Deactivated until https://github.com/JuliaArrays/ElasticArrays.jl/pull/56/files merged and pop!/popfirst! added to ElasticArrays
step = popfirst!(eb)
@test length(eb) == length(eb.sampleable_inds) - 1 == length(eb.step_numbers) - 1 == length(eb.episodes_lengths) - 1 == 9
@test first(eb.step_numbers) == 4
step = pop!(eb)
@test length(eb) == length(eb.sampleable_inds) - 1 == length(eb.step_numbers) - 1 == length(eb.episodes_lengths) - 1 == 8
@test last(eb.step_numbers) == 12
@test size(eb) == size(eb.traces) == (8,)
empty!(eb)
@test size(eb) == (0,) == size(eb.traces) == size(eb.sampleable_inds) == size(eb.episodes_lengths) == size(eb.step_numbers)
show(eb);
=#
end
end
Loading
Loading