Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Missing Elastic Methods and tests #66

Merged
merged 11 commits into from
Mar 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@ jobs:
fail-fast: false
matrix:
version:
- '1.9'
- '1'
- '^1.11.0-alpha'
- 'nightly'
os:
- ubuntu-latest
Expand Down
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "ReinforcementLearningTrajectories"
uuid = "6486599b-a3cd-4e92-a99a-2cea90cc8c3c"
version = "0.3.7"
version = "0.4"

[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
Expand Down
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
export ElasticArraySARTTraces
export ElasticArraySARTSATraces

using ElasticArrays: ElasticArray, resize_lastdim!

const ElasticArraySARTTraces = Traces{
const ElasticArraySARTSATraces = Traces{
SS′AA′RT,
<:Tuple{
<:MultiplexTraces{SS′,<:Trace{<:ElasticArray}},
Expand All @@ -12,7 +10,7 @@ const ElasticArraySARTTraces = Traces{
}
}

function ElasticArraySARTTraces(;
function ElasticArraySARTSATraces(;
state=Int => (),
action=Int => (),
reward=Float32 => (),
Expand All @@ -31,10 +29,3 @@ function ElasticArraySARTTraces(;
)
end

#####
# extensions for ElasticArrays
#####

Base.push!(a::ElasticArray, x) = append!(a, x)
Base.push!(a::ElasticArray{T,1}, x) where {T} = append!(a, [x])
Base.empty!(a::ElasticArray) = resize_lastdim!(a, 0)
30 changes: 30 additions & 0 deletions src/common/ElasticArraySARTSTraces.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
export ElasticArraySARTSTraces

const ElasticArraySARTSTraces = Traces{
SS′ART,
<:Tuple{
<:MultiplexTraces{SS′,<:Trace{<:ElasticArray}},
<:Trace{<:ElasticArray},
<:Trace{<:ElasticArray},
<:Trace{<:ElasticArray},
}
}

function ElasticArraySARTSTraces(;
state=Int => (),
action=Int => (),
reward=Float32 => (),
terminal=Bool => ())

state_eltype, state_size = state
action_eltype, action_size = action
reward_eltype, reward_size = reward
terminal_eltype, terminal_size = terminal

MultiplexTraces{SS′}(ElasticArray{state_eltype}(undef, state_size..., 0)) +
Traces(
action = ElasticArray{action_eltype}(undef, action_size..., 0),
reward=ElasticArray{reward_eltype}(undef, reward_size..., 0),
terminal=ElasticArray{terminal_eltype}(undef, terminal_size..., 0),
)
end
35 changes: 35 additions & 0 deletions src/common/ElasticArraySLARTTraces.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
export ElasticArraySLARTTraces

const ElasticArraySLARTTraces = Traces{
SS′LL′AA′RT,
<:Tuple{
<:MultiplexTraces{SS′,<:Trace{<:ElasticArray}},
<:MultiplexTraces{LL′,<:Trace{<:ElasticArray}},
<:MultiplexTraces{AA′,<:Trace{<:ElasticArray}},
<:Trace{<:ElasticArray},
<:Trace{<:ElasticArray},
}
}

function ElasticArraySLARTTraces(;
capacity::Int,
state=Int => (),
legal_actions_mask=Bool => (),
action=Int => (),
reward=Float32 => (),
terminal=Bool => ()
)
state_eltype, state_size = state
action_eltype, action_size = action
legal_actions_mask_eltype, legal_actions_mask_size = legal_actions_mask
reward_eltype, reward_size = reward
terminal_eltype, terminal_size = terminal

MultiplexTraces{SS′}(ElasticArray{state_eltype}(undef, state_size..., 0)) +
MultiplexTraces{LL′}(ElasticArray{legal_actions_mask_eltype}(undef, legal_actions_mask_size..., 0)) +
MultiplexTraces{AA′}(ElasticArray{action_eltype}(undef, action_size..., 0)) +
Traces(
reward=ElasticArray{reward_eltype}(undef, reward_size..., 0),
terminal=ElasticArray{terminal_eltype}(undef, terminal_size..., 0),
)
end
5 changes: 4 additions & 1 deletion src/common/common.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,4 +14,7 @@ include("CircularArraySARTSTraces.jl")
include("CircularArraySARTSATraces.jl")
include("CircularArraySLARTTraces.jl")
include("CircularPrioritizedTraces.jl")
include("ElasticArraySARTTraces.jl")
include("common_elastic_array.jl")
include("ElasticArraySARTSTraces.jl")
include("ElasticArraySARTSATraces.jl")
include("ElasticArraySLARTTraces.jl")
9 changes: 9 additions & 0 deletions src/common/common_elastic_array.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
using ElasticArrays: ElasticArray, resize_lastdim!

#####
# extensions for ElasticArrays
#####

Base.push!(a::ElasticArray, x) = append!(a, x)
Base.push!(a::ElasticArray{T,1}, x) where {T} = append!(a, [x])
Base.empty!(a::ElasticArray) = resize_lastdim!(a, 0)
49 changes: 30 additions & 19 deletions src/episodes.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
export EpisodesBuffer, PartialNamedTuple
import DataStructures.CircularBuffer
using ElasticArrays: ElasticArray, ElasticVector

"""
EpisodesBuffer(traces::AbstractTraces)
Expand Down Expand Up @@ -68,12 +69,20 @@ function EpisodesBuffer(traces::AbstractTraces)
end
end

Base.getindex(es::EpisodesBuffer, idx...) = getindex(es.traces, idx...)
Base.setindex!(es::EpisodesBuffer, idx...) = setindex!(es.traces, idx...)
Base.size(es::EpisodesBuffer) = size(es.traces)
Base.length(es::EpisodesBuffer) = length(es.traces)
Base.keys(es::EpisodesBuffer) = keys(es.traces)
Base.keys(es::EpisodesBuffer{<:Any,<:Any,<:CircularPrioritizedTraces}) = keys(es.traces.traces)
function Base.getindex(es::EpisodesBuffer, idx::Int...)
@boundscheck all(es.sampleable_inds[idx...])
getindex(es.traces, idx...)
end

function Base.getindex(es::EpisodesBuffer, idx...)
getindex(es.traces, idx...)
end

Base.setindex!(eb::EpisodesBuffer, idx...) = setindex!(eb.traces, idx...)
Base.size(eb::EpisodesBuffer) = size(eb.traces)
Base.length(eb::EpisodesBuffer) = length(eb.traces)
Base.keys(eb::EpisodesBuffer) = keys(eb.traces)
Base.keys(eb::EpisodesBuffer{<:Any,<:Any,<:CircularPrioritizedTraces}) = keys(eb.traces.traces)
function Base.show(io::IO, m::MIME"text/plain", eb::EpisodesBuffer{names}) where {names}
s = nameof(typeof(eb))
t = eb.traces
Expand All @@ -82,14 +91,16 @@ function Base.show(io::IO, m::MIME"text/plain", eb::EpisodesBuffer{names}) where
end

ispartial_insert(traces::Traces, xs) = length(xs) < length(traces.traces) #this is the number of traces it contains not the number of steps.
ispartial_insert(es::EpisodesBuffer, xs) = ispartial_insert(es.traces, xs)
ispartial_insert(eb::EpisodesBuffer, xs) = ispartial_insert(eb.traces, xs)
ispartial_insert(traces::CircularPrioritizedTraces, xs) = ispartial_insert(traces.traces, xs)

function pad!(trace::Trace)
pad!(trace.parent)
return nothing
end

pad!(vect::ElasticArray{T, Vector{T}}) where {T} = push!(vect, zero(T))
pad!(vect::ElasticVector{T, Vector{T}}) where {T} = push!(vect, zero(T))
pad!(buf::CircularArrayBuffer{T,N,A}) where {T,N,A} = push!(buf, zero(T))
pad!(vect::Vector{T}) where {T} = push!(vect, zero(T))

Expand Down Expand Up @@ -123,9 +134,9 @@ pad!(vect::Vector{T}) where {T} = push!(vect, zero(T))
return :($ex)
end

fill_multiplex(es::EpisodesBuffer) = fill_multiplex(es.traces)
fill_multiplex(eb::EpisodesBuffer) = fill_multiplex(eb.traces)

fill_multiplex(es::EpisodesBuffer{<:Any,<:Any,<:CircularPrioritizedTraces}) = fill_multiplex(es.traces.traces)
fill_multiplex(eb::EpisodesBuffer{<:Any,<:Any,<:CircularPrioritizedTraces}) = fill_multiplex(eb.traces.traces)

function Base.push!(eb::EpisodesBuffer, xs::NamedTuple)
push!(eb.traces, xs)
Expand Down Expand Up @@ -162,17 +173,17 @@ function Base.push!(eb::EpisodesBuffer, xs::PartialNamedTuple) #wrap a NamedTupl
end

for f in (:pop!, :popfirst!)
@eval function Base.$f(es::EpisodesBuffer)
$f(es.episodes_lengths)
$f(es.sampleable_inds)
$f(es.step_numbers)
$f(es.traces)
@eval function Base.$f(eb::EpisodesBuffer)
$f(eb.episodes_lengths)
$f(eb.sampleable_inds)
$f(eb.step_numbers)
$f(eb.traces)
end
end

function Base.empty!(es::EpisodesBuffer)
empty!(es.traces)
empty!(es.episodes_lengths)
empty!(es.sampleable_inds)
empty!(es.step_numbers)
function Base.empty!(eb::EpisodesBuffer)
empty!(eb.traces)
empty!(eb.episodes_lengths)
empty!(eb.sampleable_inds)
empty!(eb.step_numbers)
end
2 changes: 2 additions & 0 deletions src/traces.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ export Trace, Traces, MultiplexTraces
import MacroTools: @forward

import CircularArrayBuffers.CircularArrayBuffer
using ElasticArrays: ElasticArray
import Adapt

#####
Expand Down Expand Up @@ -55,6 +56,7 @@ Base.setindex!(s::Trace, v, I) = setindex!(s.parent, v, ntuple(i -> i == ndims(s
capacity(t::AbstractTrace) = ReinforcementLearningTrajectories.capacity(t.parent)
capacity(t::CircularArrayBuffer) = CircularArrayBuffers.capacity(t)
capacity(::AbstractVector) = Inf
capacity(::ElasticArray) = Inf

#####

Expand Down
8 changes: 4 additions & 4 deletions test/common.jl
Original file line number Diff line number Diff line change
Expand Up @@ -94,15 +94,15 @@ end
@test batch.terminal == Bool[0, 0, 0] |> gpu
end

@testset "ElasticArraySARTTraces" begin
t = ElasticArraySARTTraces(;
@testset "ElasticArraySARTSTraces" begin
t = ElasticArraySARTSTraces(;
state=Float32 => (2, 3),
action=Int => (),
reward=Float32 => (),
terminal=Bool => ()
)

@test t isa ElasticArraySARTTraces
@test t isa ElasticArraySARTSTraces

push!(t, (state=ones(Float32, 2, 3), action=1))
push!(t, (reward=1.0f0, terminal=false, state=ones(Float32, 2, 3) * 2, action=2))
Expand Down Expand Up @@ -185,4 +185,4 @@ end

eb[:priority, [1, 2]] = [0, 0]
@test eb[:priority] == [zeros(2);ones(8)]
end
end
Loading
Loading