Skip to content

Commit

Permalink
fix bugs when trying to adapt to RLEnvs.jl
Browse files Browse the repository at this point in the history
  • Loading branch information
findmyway committed Jun 4, 2022
1 parent 8e5ef58 commit d4b705f
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 11 deletions.
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
*.jl.*.cov
*.jl.cov
*.jl.mem
/Manifest.toml
Manifest.toml
/docs/build/
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "CommonRLSpaces"
uuid = "408f5b3e-f2a2-48a6-b4bb-c8aa44c458e6"
authors = ["Jun Tian <[email protected]> and contributors"]
version = "0.1.0"
version = "0.1.1"

[deps]
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
Expand Down
31 changes: 22 additions & 9 deletions src/basic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,31 +6,38 @@ struct Space{T}
s::T
end

Space(s::Type{T}) where {T} = Space(typemin(T):typemax(T))
Space(s::Type{T}) where {T<:Number} = Space(typemin(T):typemax(T))

Space(x, dims::Int...) = Space(fill(x, dims))
Space(x::Type{T}, dims::Int...) where {T<:Integer} = Space(fill(typemin(x):typemax(T), dims))
Space(x::Type{T}, dims::Int...) where {T<:AbstractFloat} = Space(fill(typemin(x) .. typemax(T), dims))
Space(x, dims::Int...) = Space(Fill(x, dims))
Space(x::Type{T}, dim::Int, extra_dims::Int...) where {T<:Integer} = Space(Fill(typemin(x):typemax(T), dim, extra_dims...))
Space(x::Type{T}, dim::Int, extra_dims::Int...) where {T<:AbstractFloat} = Space(Fill(typemin(x) .. typemax(T), dim, extra_dims...))
Space(x::Type{T}, dim::Int, extra_dims::Int...) where {T} = Space(Fill(T, dim, extra_dims...))

Base.size(s::Space) = size(SpaceStyle(s))
Base.length(s::Space) = length(SpaceStyle(s), s)
Base.getindex(s::Space, i...) = getindex(SpaceStyle(s), s, i...)
Base.:(==)(s1::Space, s2::Space) = s1.s == s2.s

#####

abstract type AbstractSpaceStyle{S} end

Base.size(::AbstractSpaceStyle{S}) where {S} = S

struct DiscreteSpaceStyle{S} <: AbstractSpaceStyle{S} end
struct ContinuousSpaceStyle{S} <: AbstractSpaceStyle{S} end

SpaceStyle(::Space{<:Tuple}) = DiscreteSpaceStyle{()}()
SpaceStyle(::Space{<:AbstractRange}) = DiscreteSpaceStyle{()}()
SpaceStyle(::Space{<:AbstractVector{<:Number}}) = DiscreteSpaceStyle{()}()
SpaceStyle(::Space{<:AbstractInterval}) = ContinuousSpaceStyle{()}()

SpaceStyle(s::Space{<:AbstractArray{<:Tuple}}) = DiscreteSpaceStyle{size(s.s)}()
SpaceStyle(s::Space{<:AbstractArray{<:AbstractRange}}) = DiscreteSpaceStyle{size(s.s)}()
SpaceStyle(s::Space{<:AbstractArray{<:AbstractInterval}}) = ContinuousSpaceStyle{size(s.s)}()

Base.size(::AbstractSpaceStyle{S}) where {S} = S
Base.length(::DiscreteSpaceStyle{()}, s) = length(s.s)
Base.getindex(::DiscreteSpaceStyle{()}, s, i...) = getindex(s.s, i...)
Base.length(::DiscreteSpaceStyle, s) = mapreduce(length, *, s.s)

#####

Random.rand(rng::Random.AbstractRNG, s::Space) = rand(rng, s.s)
Expand All @@ -45,6 +52,7 @@ Random.rand(
) = map(x -> rand(rng, x), s.s)

Base.in(x, s::Space) = x in s.s
Base.in(x, s::Space{<:Type}) = x isa s.s

Base.in(
x,
Expand All @@ -69,15 +77,20 @@ function Random.rand(rng::AbstractRNG, s::Interval{:closed,:closed,T}) where {T}
end
end

Base.iterate(s::Space, args...) = iterate(SpaceStyle(s), s, args...)
Base.iterate(::DiscreteSpaceStyle{()}, s::Space, args...) = iterate(s.s, args...)

#####

const TupleSpace = Tuple{Vararg{<:Space}}
const TupleSpace = Tuple{Vararg{Space}}
const NamedSpace = NamedTuple{<:Any,<:TupleSpace}
const VectorSpace = Vector{<:Space}
const DictSpace = Dict{<:Any,<:Space}

Random.rand(rng::AbstractRNG, s::Union{TupleSpace,NamedSpace}) = map(x -> rand(rng, x), s)
Random.rand(rng::AbstractRNG, s::Union{TupleSpace,NamedSpace,VectorSpace}) = map(x -> rand(rng, x), s)
Random.rand(rng::AbstractRNG, s::DictSpace) = Dict(k => rand(rng, s[k]) for k in keys(s))

Base.in(xs::Tuple, ts::TupleSpace) = length(xs) == length(ts) && all(((x, s),) -> x in s, zip(xs, ts))
Base.in(xs::AbstractVector, ts::VectorSpace) = length(xs) == length(ts) && all(((x, s),) -> x in s, zip(xs, ts))
Base.in(xs::NamedTuple{names}, ns::NamedTuple{names,<:TupleSpace}) where {names} = all(((x, s),) -> x in s, zip(xs, ns))
Base.in(xs::Dict, ds::DictSpace) = length(xs) == length(ds) && all(k -> haskey(ds, k) && xs[k] in ds[k], keys(xs))

0 comments on commit d4b705f

Please sign in to comment.