This repository has been archived by the owner on Aug 11, 2023. It is now read-only.

Commit

Simplify code structure (#112)
* keep only interfaces here

* minor improvement

* remove unused test cases

* update readme
findmyway authored Dec 18, 2020
1 parent 9e60c91 commit b1abf25
Showing 28 changed files with 33 additions and 1,399 deletions.
4 changes: 0 additions & 4 deletions Project.toml
@@ -6,15 +6,11 @@ version = "0.9.0"
[deps]
AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c"
CommonRLInterface = "d842c3ba-07a1-494f-bbec-f5741b0a3e98"
IntervalSets = "8197267c-284f-5f27-9208-e0e47529a953"
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
Markdown = "d6f4376e-aef5-505a-96c1-9c027394607a"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[compat]
AbstractTrees = "0.3"
CommonRLInterface = "0.2"
IntervalSets = "0.5"
MacroTools = "0.5"
julia = "1.3"
49 changes: 9 additions & 40 deletions README.md
@@ -2,43 +2,12 @@

[![Build Status](https://travis-ci.com/JuliaReinforcementLearning/ReinforcementLearningBase.jl.svg?branch=master)](https://travis-ci.com/JuliaReinforcementLearning/ReinforcementLearningBase.jl)

ReinforcementLearningBase.jl holds the common types and utility functions shared by
other components in the ReinforcementLearning ecosystem.


## Examples

<table>
<th colspan="2">Traits</th><th> 1 </th><th> 2 </th><th> 3 </th><th> 4 </th><th> 5 </th><th> 6 </th><th> 7 </th><th> 8 </th><th> 9 </th><tr> <th rowspan="2"> ActionStyle </th><th> MinimalActionSet </th><td> ✔ </td><td> ✔ </td><td> ✔ </td><td> </td> <td> ✔ </td><td> </td> <td> ✔ </td><td> ✔ </td><td> ✔ </td></tr>
<tr> <th> FullActionSet </th><td> </td> <td> </td> <td> </td> <td> ✔ </td><td> </td> <td> ✔ </td><td> </td> <td> </td> <td> </td> </tr>
<tr> <th rowspan="3"> ChanceStyle </th><th> Stochastic </th><td> ✔ </td><td> </td> <td> ✔ </td><td> ✔ </td><td> </td> <td> </td> <td> </td> <td> </td> <td> </td> </tr>
<tr> <th> Deterministic </th><td> </td> <td> ✔ </td><td> </td> <td> </td> <td> ✔ </td><td> ✔ </td><td> </td> <td> </td> <td> </td> </tr>
<tr> <th> ExplicitStochastic </th><td> </td> <td> </td> <td> </td> <td> </td> <td> </td> <td> </td> <td> ✔ </td><td> ✔ </td><td> ✔ </td></tr>
<tr> <th rowspan="2"> DefaultStateStyle </th><th> Observation </th><td> ✔ </td><td> ✔ </td><td> ✔ </td><td> ✔ </td><td> ✔ </td><td> ✔ </td><td> </td> <td> ✔ </td><td> </td> </tr>
<tr> <th> InformationSet </th><td> </td> <td> </td> <td> </td> <td> </td> <td> </td> <td> </td> <td> ✔ </td><td> </td> <td> ✔ </td></tr>
<tr> <th rowspan="2"> DynamicStyle </th><th> Simultaneous </th><td> </td> <td> </td> <td> </td> <td> </td> <td> ✔ </td><td> </td> <td> </td> <td> </td> <td> </td> </tr>
<tr> <th> Sequential </th><td> ✔ </td><td> ✔ </td><td> ✔ </td><td> ✔ </td><td> </td> <td> ✔ </td><td> ✔ </td><td> ✔ </td><td> ✔ </td></tr>
<tr> <th rowspan="2"> InformationStyle </th><th> PerfectInformation </th><td> </td> <td> ✔ </td><td> </td> <td> </td> <td> </td> <td> ✔ </td><td> </td> <td> ✔ </td><td> </td> </tr>
<tr> <th> ImperfectInformation </th><td> ✔ </td><td> </td> <td> ✔ </td><td> ✔ </td><td> ✔ </td><td> </td> <td> ✔ </td><td> </td> <td> ✔ </td></tr>
<tr> <th rowspan="2"> NumAgentStyle </th><th> MultiAgent </th><td> </td> <td> </td> <td> </td> <td> </td> <td> ✔ </td><td> ✔ </td><td> ✔ </td><td> ✔ </td><td> ✔ </td></tr>
<tr> <th> SingleAgent </th><td> ✔ </td><td> ✔ </td><td> ✔ </td><td> ✔ </td><td> </td> <td> </td> <td> </td> <td> </td> <td> </td> </tr>
<tr> <th rowspan="2"> RewardStyle </th><th> TerminalReward </th><td> ✔ </td><td> ✔ </td><td> </td> <td> ✔ </td><td> ✔ </td><td> ✔ </td><td> ✔ </td><td> ✔ </td><td> ✔ </td></tr>
<tr> <th> StepReward </th><td> </td> <td> </td> <td> ✔ </td><td> </td> <td> </td> <td> </td> <td> </td> <td> </td> <td> </td> </tr>
<tr> <th rowspan="3"> StateStyle </th><th> Observation </th><td> ✔ </td><td> ✔ </td><td> ✔ </td><td> ✔ </td><td> ✔ </td><td> ✔ </td><td> </td> <td> ✔ </td><td> </td> </tr>
<tr> <th> InformationSet </th><td> </td> <td> </td> <td> </td> <td> </td> <td> </td> <td> </td> <td> ✔ </td><td> </td> <td> ✔ </td></tr>
<tr> <th> InternalState </th><td> </td> <td> </td> <td> ✔ </td><td> </td> <td> </td> <td> </td> <td> </td> <td> </td> <td> </td> </tr>
<tr> <th rowspan="4"> UtilityStyle </th><th> GeneralSum </th><td> ✔ </td><td> ✔ </td><td> ✔ </td><td> ✔ </td><td> </td> <td> </td> <td> </td> <td> </td> <td> </td> </tr>
<tr> <th> ZeroSum </th><td> </td> <td> </td> <td> </td> <td> </td> <td> ✔ </td><td> ✔ </td><td> </td> <td> </td> <td> ✔ </td></tr>
<tr> <th> ConstantSum </th><td> </td> <td> </td> <td> </td> <td> </td> <td> </td> <td> </td> <td> </td> <td> ✔ </td><td> </td> </tr>
<tr> <th> IdenticalUtility </th><td> </td> <td> </td> <td> </td> <td> </td> <td> </td> <td> </td> <td> ✔ </td><td> </td> <td> </td> </tr>
</table>
<ol><li> <a href="https://github.com/JuliaReinforcementLearning/ReinforcementLearningBase.jl/tree/master/src/examples/MultiArmBanditsEnv.jl"> MultiArmBanditsEnv </a></li>
<li> <a href="https://github.com/JuliaReinforcementLearning/ReinforcementLearningBase.jl/tree/master/src/examples/RandomWalk1D.jl"> RandomWalk1D </a></li>
<li> <a href="https://github.com/JuliaReinforcementLearning/ReinforcementLearningBase.jl/tree/master/src/examples/TigerProblemEnv.jl"> TigerProblemEnv </a></li>
<li> <a href="https://github.com/JuliaReinforcementLearning/ReinforcementLearningBase.jl/tree/master/src/examples/MontyHallEnv.jl"> MontyHallEnv </a></li>
<li> <a href="https://github.com/JuliaReinforcementLearning/ReinforcementLearningBase.jl/tree/master/src/examples/RockPaperScissorsEnv.jl"> RockPaperScissorsEnv </a></li>
<li> <a href="https://github.com/JuliaReinforcementLearning/ReinforcementLearningBase.jl/tree/master/src/examples/TicTacToeEnv.jl"> TicTacToeEnv </a></li>
<li> <a href="https://github.com/JuliaReinforcementLearning/ReinforcementLearningBase.jl/tree/master/src/examples/TinyHanabiEnv.jl"> TinyHanabiEnv </a></li>
<li> <a href="https://github.com/JuliaReinforcementLearning/ReinforcementLearningBase.jl/tree/master/src/examples/PigEnv.jl"> PigEnv </a></li>
<li> <a href="https://github.com/JuliaReinforcementLearning/ReinforcementLearningBase.jl/tree/master/src/examples/KuhnPokerEnv.jl"> KuhnPokerEnv </a></li>
</ol>
This package defines two core concepts in reinforcement learning:

- `AbstractEnv`. Check out
  [ReinforcementLearningEnvironments.jl](https://github.com/JuliaReinforcementLearning/ReinforcementLearningEnvironments.jl)
  for a wide variety of ready-to-use environments.
- `AbstractPolicy`.
  [ReinforcementLearningCore.jl](https://github.com/JuliaReinforcementLearning/ReinforcementLearningCore.jl)
  is a good starting point for writing customized policies (see the sketch below).
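
As an illustrative aside (not part of this commit or the README), a policy in this ecosystem is simply a callable object that maps an environment to an action. A minimal sketch, assuming the exported `AbstractPolicy`, `AbstractEnv`, and `legal_action_space` interface functions; the name `RandomLegalPolicy` is made up here:

```julia
using ReinforcementLearningBase

# Hypothetical sketch only; RandomLegalPolicy is not part of the package.
struct RandomLegalPolicy <: AbstractPolicy end

# A policy is called with the environment and returns an action;
# here we simply draw uniformly from the legal action set.
(p::RandomLegalPolicy)(env::AbstractEnv) = rand(legal_action_space(env))
```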
1 change: 0 additions & 1 deletion src/ReinforcementLearningBase.jl
@@ -9,6 +9,5 @@ include("inline_export.jl")
include("interface.jl")
include("CommonRLInterface.jl")
include("base.jl")
include("examples/examples.jl")

end # module
221 changes: 15 additions & 206 deletions src/base.jl
@@ -191,213 +191,22 @@ function test_interfaces!(env)
reset!(env)
end

#####
# Generate README
#####

gen_traits_table(envs) = gen_traits_table(stdout, envs)

function gen_traits_table(io, envs)
    trait_dict = Dict()
    for f in env_traits()
        for env in envs
            if !haskey(trait_dict, f)
                trait_dict[f] = Set()
            end
            t = f(env)
            if f == StateStyle
                if t isa Tuple
                    for x in t
                        push!(trait_dict[f], nameof(typeof(x)))
                    end
                else
                    push!(trait_dict[f], nameof(typeof(t)))
                end
            else
                push!(trait_dict[f], nameof(typeof(t)))
            end
        end
    end

    println(io, "<table>")

    print(io, "<th colspan=\"2\">Traits</th>")
    for i in 1:length(envs)
        print(io, "<th> $(i) </th>")
    end

    for k in sort(collect(keys(trait_dict)), by = nameof)
        vs = trait_dict[k]
        print(io, "<tr> <th rowspan=\"$(length(vs))\"> $(nameof(k)) </th>")
        for (i, v) in enumerate(vs)
            if i != 1
                print(io, "<tr> ")
            end
            print(io, "<th> $(v) </th>")
            for env in envs
                if k == StateStyle && k(env) isa Tuple
                    ss = k(env)
                    if v in map(x -> nameof(typeof(x)), ss)
                        print(io, "<td> ✔ </td>")
                    else
                        print(io, "<td> </td> ")
                    end
                else
                    if nameof(typeof(k(env))) == v
                        print(io, "<td> ✔ </td>")
                    else
                        print(io, "<td> </td> ")
                    end
                end
            end
            println(io, "</tr>")
        end
    end

    println(io, "</table>")

    print(io, "<ol>")
    for env in envs
        println(
            io,
            "<li> <a href=\"https://github.com/JuliaReinforcementLearning/ReinforcementLearningBase.jl/tree/master/src/examples/$(nameof(env)).jl\"> $(nameof(env)) </a></li>",
        )
    end
    print(io, "</ol>")
end

#####
# Utils
#####

using IntervalSets

Random.rand(s::Union{Interval,Array{<:Interval}}) = rand(Random.GLOBAL_RNG, s)

function Random.rand(rng::AbstractRNG, s::Interval)
    rand(rng) * (s.right - s.left) + s.left
end
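
For context, a quick usage sketch of the `rand` extension being removed here (illustrative only; assumes `IntervalSets`' `..` constructor):

```julia
using IntervalSets, Random

rng = MersenneTwister(0)
x = rand(rng, 0.0..1.0)   # uniform draw from the interval [0, 1]
@assert 0.0 <= x <= 1.0
```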

#####
# WorldSpace
#####

export WorldSpace

"""
In some cases, we may not be interested in the action/state space.
One can return `WorldSpace()` to keep the interface consistent.
"""
struct WorldSpace{T} end

WorldSpace() = WorldSpace{Any}()

Base.in(x, ::WorldSpace{T}) where {T} = x isa T
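
A short membership sketch for the removed `WorldSpace` (illustrative only; assumes the definitions above are in scope):

```julia
# WorldSpace{T} treats the whole world of values of type T as a space.
@assert 1 in WorldSpace()               # WorldSpace() == WorldSpace{Any}() admits everything
@assert "hello" in WorldSpace()
@assert !(1 in WorldSpace{String}())    # only Strings belong to WorldSpace{String}()
```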

#####
# ZeroTo
#####

export ZeroTo

"""
Similar to `Base.OneTo`. Useful when wrapping third-party environments.
"""
struct ZeroTo{T<:Integer} <: AbstractUnitRange{T}
    stop::T
    ZeroTo{T}(n) where {T<:Integer} = new(max(zero(T) - one(T), n))
end

ZeroTo(n::T) where {T<:Integer} = ZeroTo{T}(n)

Base.show(io::IO, r::ZeroTo) = print(io, "ZeroTo(", r.stop, ")")
Base.length(r::ZeroTo{T}) where {T} = T(r.stop + one(r.stop))
Base.first(r::ZeroTo{T}) where {T} = zero(r.stop)

function getindex(v::ZeroTo{T}, i::Integer) where {T}
    Base.@_inline_meta
    @boundscheck ((i >= 0) & (i <= v.stop)) || throw_boundserror(v, i)
    convert(T, i)
end
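
A brief sketch of how the removed `ZeroTo` behaves (illustrative only; assumes the definitions above are in scope):

```julia
# ZeroTo(n) acts like the integer range 0:n, which is handy when wrapping
# environments that expose 0-based action or state ids.
r = ZeroTo(3)
@assert length(r) == 4
@assert first(r) == 0
@assert 3 in r && !(4 in r)
```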

#####
# ActionProbPair
#####

export ActionProbPair

"""
Used in the action space of a chance player.
"""
struct ActionProbPair{A,P}
    action::A
    prob::P
end

"""
Directly copied from [StatsBase.jl](https://github.com/JuliaStats/StatsBase.jl/blob/0ea8e798c3d19609ed33b11311de5a2bd6ee9fd0/src/sampling.jl#L499-L510) to avoid depending on the whole package.
Here we assume `wv` sums to `1`.
"""
function weighted_sample(rng::AbstractRNG, wv)
    t = rand(rng)
    cw = zero(Base.first(wv))
    for (i, w) in enumerate(wv)
        cw += w
        if cw >= t
            return i
        end
    end
end
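
A quick sanity-check sketch of `weighted_sample` (illustrative only; the weights are assumed to sum to 1, as the docstring notes):

```julia
using Random

rng = MersenneTwister(42)
counts = zeros(Int, 3)
for _ in 1:100_000
    counts[weighted_sample(rng, (0.2, 0.3, 0.5))] += 1  # returns an index in 1:3
end
counts ./ 100_000   # ≈ [0.2, 0.3, 0.5]
```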

Random.rand(rng::AbstractRNG, s::AbstractVector{<:ActionProbPair}) =
    s[weighted_sample(rng, (x.prob for x in s))]

(env::AbstractEnv)(a::ActionProbPair) = env(a.action)
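
And a small sketch of sampling from a chance player's action space built from `ActionProbPair` (illustrative only; the action names are made up):

```julia
using Random

chance_actions = [ActionProbPair(:left, 0.25), ActionProbPair(:right, 0.75)]
a = rand(MersenneTwister(0), chance_actions)   # weighted by the prob field
@assert a isa ActionProbPair
@assert a.action in (:left, :right)
```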

#####
# Space
#####

export Space

"""
A wrapper that treats each element as a sub-space, supporting `Random.rand` and `Base.in`.
"""
struct Space{T}
    s::T
end

Random.rand(s::Space) = rand(Random.GLOBAL_RNG, s)

Random.rand(rng::AbstractRNG, s::Space) =
    map(s.s) do x
        rand(rng, x)
    end

Random.rand(rng::AbstractRNG, s::Space{<:Dict}) = Dict(k => rand(rng, v) for (k, v) in s.s)

function Base.in(X, S::Space)
    if length(X) == length(S.s)
        for (x, s) in zip(X, S.s)
            if x ∉ s
                return false
            end
        end
        return true
    else
        return false
    end
end
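
A usage sketch of the removed `Space` wrapper (illustrative only; the interval sub-space assumes `IntervalSets`):

```julia
using IntervalSets, Random

s = Space((1:3, 0.0..1.0))        # a tuple of sub-spaces
x = rand(MersenneTwister(0), s)   # samples each sub-space element-wise
@assert x in s
@assert !((10, 0.5) in s)         # 10 is outside the 1:3 sub-space
```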

function Base.in(X::Dict, S::Space{<:Dict})
    if keys(X) == keys(S.s)
        for k in keys(X)
            if X[k] ∉ S.s[k]
                return false
            end
        end
        return true
    else
        return false
    end
end

function test_runnable!(env, n = 1000; rng = Random.GLOBAL_RNG)
    @testset "random policy with $(nameof(env))" begin
        reset!(env)
        for _ in 1:n
            A = legal_action_space(env)
            a = rand(rng, A)
            @test a in A

            S = state_space(env)
            s = state(env)
            @test s in S
            env(a)
            if is_terminated(env)
                reset!(env)
            end
        end
        reset!(env)
    end
end
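
The newly added `test_runnable!` drives an environment with purely random actions for a fixed number of steps, resetting whenever the episode terminates. A hypothetical invocation, where `MyEnv` is a placeholder for any concrete `AbstractEnv` implementation:

```julia
using Random

env = MyEnv()   # placeholder; substitute a concrete AbstractEnv
test_runnable!(env, 500; rng = MersenneTwister(123))
```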
