Skip to content

Commit

Permalink
Merge pull request #66 from JuliaReinforcementLearning/jpsl/elastic
Browse files Browse the repository at this point in the history
Add Missing Elastic Methods and tests
  • Loading branch information
jeremiahpslewis authored Mar 22, 2024
2 parents 9625f16 + 4895a41 commit f222bad
Show file tree
Hide file tree
Showing 12 changed files with 339 additions and 38 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@ jobs:
fail-fast: false
matrix:
version:
- '1.9'
- '1'
- '^1.11.0-alpha'
- 'nightly'
os:
- ubuntu-latest
Expand Down
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "ReinforcementLearningTrajectories"
uuid = "6486599b-a3cd-4e92-a99a-2cea90cc8c3c"
version = "0.3.7"
version = "0.4"

[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
Expand Down
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
export ElasticArraySARTTraces
export ElasticArraySARTSATraces

using ElasticArrays: ElasticArray, resize_lastdim!

const ElasticArraySARTTraces = Traces{
const ElasticArraySARTSATraces = Traces{
SS′AA′RT,
<:Tuple{
<:MultiplexTraces{SS′,<:Trace{<:ElasticArray}},
Expand All @@ -12,7 +10,7 @@ const ElasticArraySARTTraces = Traces{
}
}

function ElasticArraySARTTraces(;
function ElasticArraySARTSATraces(;
state=Int => (),
action=Int => (),
reward=Float32 => (),
Expand All @@ -31,10 +29,3 @@ function ElasticArraySARTTraces(;
)
end

#####
# extensions for ElasticArrays
#####

Base.push!(a::ElasticArray, x) = append!(a, x)
Base.push!(a::ElasticArray{T,1}, x) where {T} = append!(a, [x])
Base.empty!(a::ElasticArray) = resize_lastdim!(a, 0)
30 changes: 30 additions & 0 deletions src/common/ElasticArraySARTSTraces.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
export ElasticArraySARTSTraces

const ElasticArraySARTSTraces = Traces{
SS′ART,
<:Tuple{
<:MultiplexTraces{SS′,<:Trace{<:ElasticArray}},
<:Trace{<:ElasticArray},
<:Trace{<:ElasticArray},
<:Trace{<:ElasticArray},
}
}

function ElasticArraySARTSTraces(;
state=Int => (),
action=Int => (),
reward=Float32 => (),
terminal=Bool => ())

state_eltype, state_size = state
action_eltype, action_size = action
reward_eltype, reward_size = reward
terminal_eltype, terminal_size = terminal

MultiplexTraces{SS′}(ElasticArray{state_eltype}(undef, state_size..., 0)) +
Traces(
action = ElasticArray{action_eltype}(undef, action_size..., 0),
reward=ElasticArray{reward_eltype}(undef, reward_size..., 0),
terminal=ElasticArray{terminal_eltype}(undef, terminal_size..., 0),
)
end
35 changes: 35 additions & 0 deletions src/common/ElasticArraySLARTTraces.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
export ElasticArraySLARTTraces

const ElasticArraySLARTTraces = Traces{
SS′LL′AA′RT,
<:Tuple{
<:MultiplexTraces{SS′,<:Trace{<:ElasticArray}},
<:MultiplexTraces{LL′,<:Trace{<:ElasticArray}},
<:MultiplexTraces{AA′,<:Trace{<:ElasticArray}},
<:Trace{<:ElasticArray},
<:Trace{<:ElasticArray},
}
}

function ElasticArraySLARTTraces(;
capacity::Int,
state=Int => (),
legal_actions_mask=Bool => (),
action=Int => (),
reward=Float32 => (),
terminal=Bool => ()
)
state_eltype, state_size = state
action_eltype, action_size = action
legal_actions_mask_eltype, legal_actions_mask_size = legal_actions_mask
reward_eltype, reward_size = reward
terminal_eltype, terminal_size = terminal

MultiplexTraces{SS′}(ElasticArray{state_eltype}(undef, state_size..., 0)) +
MultiplexTraces{LL′}(ElasticArray{legal_actions_mask_eltype}(undef, legal_actions_mask_size..., 0)) +
MultiplexTraces{AA′}(ElasticArray{action_eltype}(undef, action_size..., 0)) +
Traces(
reward=ElasticArray{reward_eltype}(undef, reward_size..., 0),
terminal=ElasticArray{terminal_eltype}(undef, terminal_size..., 0),
)
end
5 changes: 4 additions & 1 deletion src/common/common.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,4 +14,7 @@ include("CircularArraySARTSTraces.jl")
include("CircularArraySARTSATraces.jl")
include("CircularArraySLARTTraces.jl")
include("CircularPrioritizedTraces.jl")
include("ElasticArraySARTTraces.jl")
include("common_elastic_array.jl")
include("ElasticArraySARTSTraces.jl")
include("ElasticArraySARTSATraces.jl")
include("ElasticArraySLARTTraces.jl")
9 changes: 9 additions & 0 deletions src/common/common_elastic_array.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
using ElasticArrays: ElasticArray, resize_lastdim!

#####
# extensions for ElasticArrays
#####

Base.push!(a::ElasticArray, x) = append!(a, x)
Base.push!(a::ElasticArray{T,1}, x) where {T} = append!(a, [x])
Base.empty!(a::ElasticArray) = resize_lastdim!(a, 0)
49 changes: 30 additions & 19 deletions src/episodes.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
export EpisodesBuffer, PartialNamedTuple
import DataStructures.CircularBuffer
using ElasticArrays: ElasticArray, ElasticVector

"""
EpisodesBuffer(traces::AbstractTraces)
Expand Down Expand Up @@ -68,12 +69,20 @@ function EpisodesBuffer(traces::AbstractTraces)
end
end

Base.getindex(es::EpisodesBuffer, idx...) = getindex(es.traces, idx...)
Base.setindex!(es::EpisodesBuffer, idx...) = setindex!(es.traces, idx...)
Base.size(es::EpisodesBuffer) = size(es.traces)
Base.length(es::EpisodesBuffer) = length(es.traces)
Base.keys(es::EpisodesBuffer) = keys(es.traces)
Base.keys(es::EpisodesBuffer{<:Any,<:Any,<:CircularPrioritizedTraces}) = keys(es.traces.traces)
function Base.getindex(es::EpisodesBuffer, idx::Int...)
@boundscheck all(es.sampleable_inds[idx...])
getindex(es.traces, idx...)
end

function Base.getindex(es::EpisodesBuffer, idx...)
getindex(es.traces, idx...)
end

Base.setindex!(eb::EpisodesBuffer, idx...) = setindex!(eb.traces, idx...)
Base.size(eb::EpisodesBuffer) = size(eb.traces)
Base.length(eb::EpisodesBuffer) = length(eb.traces)
Base.keys(eb::EpisodesBuffer) = keys(eb.traces)
Base.keys(eb::EpisodesBuffer{<:Any,<:Any,<:CircularPrioritizedTraces}) = keys(eb.traces.traces)
function Base.show(io::IO, m::MIME"text/plain", eb::EpisodesBuffer{names}) where {names}
s = nameof(typeof(eb))
t = eb.traces
Expand All @@ -82,14 +91,16 @@ function Base.show(io::IO, m::MIME"text/plain", eb::EpisodesBuffer{names}) where
end

ispartial_insert(traces::Traces, xs) = length(xs) < length(traces.traces) #this is the number of traces it contains not the number of steps.
ispartial_insert(es::EpisodesBuffer, xs) = ispartial_insert(es.traces, xs)
ispartial_insert(eb::EpisodesBuffer, xs) = ispartial_insert(eb.traces, xs)
ispartial_insert(traces::CircularPrioritizedTraces, xs) = ispartial_insert(traces.traces, xs)

function pad!(trace::Trace)
pad!(trace.parent)
return nothing
end

pad!(vect::ElasticArray{T, Vector{T}}) where {T} = push!(vect, zero(T))
pad!(vect::ElasticVector{T, Vector{T}}) where {T} = push!(vect, zero(T))
pad!(buf::CircularArrayBuffer{T,N,A}) where {T,N,A} = push!(buf, zero(T))
pad!(vect::Vector{T}) where {T} = push!(vect, zero(T))

Expand Down Expand Up @@ -123,9 +134,9 @@ pad!(vect::Vector{T}) where {T} = push!(vect, zero(T))
return :($ex)
end

fill_multiplex(es::EpisodesBuffer) = fill_multiplex(es.traces)
fill_multiplex(eb::EpisodesBuffer) = fill_multiplex(eb.traces)

fill_multiplex(es::EpisodesBuffer{<:Any,<:Any,<:CircularPrioritizedTraces}) = fill_multiplex(es.traces.traces)
fill_multiplex(eb::EpisodesBuffer{<:Any,<:Any,<:CircularPrioritizedTraces}) = fill_multiplex(eb.traces.traces)

function Base.push!(eb::EpisodesBuffer, xs::NamedTuple)
push!(eb.traces, xs)
Expand Down Expand Up @@ -162,17 +173,17 @@ function Base.push!(eb::EpisodesBuffer, xs::PartialNamedTuple) #wrap a NamedTupl
end

for f in (:pop!, :popfirst!)
@eval function Base.$f(es::EpisodesBuffer)
$f(es.episodes_lengths)
$f(es.sampleable_inds)
$f(es.step_numbers)
$f(es.traces)
@eval function Base.$f(eb::EpisodesBuffer)
$f(eb.episodes_lengths)
$f(eb.sampleable_inds)
$f(eb.step_numbers)
$f(eb.traces)
end
end

function Base.empty!(es::EpisodesBuffer)
empty!(es.traces)
empty!(es.episodes_lengths)
empty!(es.sampleable_inds)
empty!(es.step_numbers)
function Base.empty!(eb::EpisodesBuffer)
empty!(eb.traces)
empty!(eb.episodes_lengths)
empty!(eb.sampleable_inds)
empty!(eb.step_numbers)
end
2 changes: 2 additions & 0 deletions src/traces.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ export Trace, Traces, MultiplexTraces
import MacroTools: @forward

import CircularArrayBuffers.CircularArrayBuffer
using ElasticArrays: ElasticArray
import Adapt

#####
Expand Down Expand Up @@ -55,6 +56,7 @@ Base.setindex!(s::Trace, v, I) = setindex!(s.parent, v, ntuple(i -> i == ndims(s
capacity(t::AbstractTrace) = ReinforcementLearningTrajectories.capacity(t.parent)
capacity(t::CircularArrayBuffer) = CircularArrayBuffers.capacity(t)
capacity(::AbstractVector) = Inf
capacity(::ElasticArray) = Inf

#####

Expand Down
8 changes: 4 additions & 4 deletions test/common.jl
Original file line number Diff line number Diff line change
Expand Up @@ -94,15 +94,15 @@ end
@test batch.terminal == Bool[0, 0, 0] |> gpu
end

@testset "ElasticArraySARTTraces" begin
t = ElasticArraySARTTraces(;
@testset "ElasticArraySARTSTraces" begin
t = ElasticArraySARTSTraces(;
state=Float32 => (2, 3),
action=Int => (),
reward=Float32 => (),
terminal=Bool => ()
)

@test t isa ElasticArraySARTTraces
@test t isa ElasticArraySARTSTraces

push!(t, (state=ones(Float32, 2, 3), action=1))
push!(t, (reward=1.0f0, terminal=false, state=ones(Float32, 2, 3) * 2, action=2))
Expand Down Expand Up @@ -185,4 +185,4 @@ end

eb[:priority, [1, 2]] = [0, 0]
@test eb[:priority] == [zeros(2);ones(8)]
end
end
Loading

0 comments on commit f222bad

Please sign in to comment.