From cd77d25854fca8059d6b54697ff8af9ff0bf3649 Mon Sep 17 00:00:00 2001
From: Jun Tian
Date: Wed, 28 Aug 2019 18:30:17 +0800
Subject: [PATCH] Api change (#19)

* unify interfaces

* update screen after interact

* add minor comment

* add version check

* update docker

* revert to julia v1.2 due to https://github.com/JuliaLang/julia/pull/32408#issuecomment-522168938

* update README
---
 .travis.yml                                   |  4 +-
 Dockerfile                                    |  2 +-
 README.md                                     |  6 +-
 src/ReinforcementLearningEnvironments.jl      |  3 +
 src/abstractenv.jl                            | 26 +++++-
 src/environments/atari.jl                     | 24 +++---
 src/environments/classic_control/cart_pole.jl | 11 +--
 src/environments/classic_control/mdp.jl       | 79 ++++++++++---------
 .../classic_control/mountain_car.jl           | 11 ++-
 src/environments/classic_control/pendulum.jl  | 30 ++++---
 src/environments/gym.jl                       | 18 +++--
 src/environments/hanabi.jl                    | 10 ++-
 src/spaces/multi_continuous_space.jl          |  4 +-
 test/environments.jl                          | 11 +--
 14 files changed, 146 insertions(+), 93 deletions(-)

diff --git a/.travis.yml b/.travis.yml
index a6826a6..6d081c1 100644
--- a/.travis.yml
+++ b/.travis.yml
@@ -3,9 +3,7 @@ language: julia
 os:
   - linux
 julia:
-  - 1.0
-  - 1.1
-  - nightly
+  - 1.2
 notifications:
   email: false
 
diff --git a/Dockerfile b/Dockerfile
index bea7695..9e267f8 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -1,4 +1,4 @@
-FROM julia:1.1
+FROM julia:1.2
 
 # install dependencies
 RUN set -eux; \
diff --git a/README.md b/README.md
index eedaf6c..ab79ff2 100644
--- a/README.md
+++ b/README.md
@@ -7,7 +7,7 @@ This package serves as a one-stop place for different kinds of reinforcement lea
 Install:
 
 ```julia
-(v1.1) pkg> add https://github.com/JuliaReinforcementLearning/ReinforcementLearningEnvironments.jl
+pkg> add ReinforcementLearningEnvironments
 ```
 
 ## API
@@ -64,11 +64,11 @@ Take the `AtariEnv` for example:
 
 1. Install this package by:
     ```julia
-    (v1.1) pkg> add ReinforcementLearningEnvironments
+    pkg> add ReinforcementLearningEnvironments
     ```
 2. Install corresponding dependent package by:
     ```julia
-    (v1.1) pkg> add ArcadeLearningEnvironment
+    pkg> add ArcadeLearningEnvironment
     ```
 3. Using the above two packages:
     ```julia
diff --git a/src/ReinforcementLearningEnvironments.jl b/src/ReinforcementLearningEnvironments.jl
index 087331a..6e5ba9f 100644
--- a/src/ReinforcementLearningEnvironments.jl
+++ b/src/ReinforcementLearningEnvironments.jl
@@ -1,5 +1,8 @@
 module ReinforcementLearningEnvironments
 
+export RLEnvs
+const RLEnvs = ReinforcementLearningEnvironments
+
 using Reexport, Requires
 
 include("abstractenv.jl")
diff --git a/src/abstractenv.jl b/src/abstractenv.jl
index bc49bc9..38521c7 100644
--- a/src/abstractenv.jl
+++ b/src/abstractenv.jl
@@ -1,4 +1,4 @@
-export AbstractEnv, observe, reset!, interact!, action_space, observation_space, render
+export AbstractEnv, observe, reset!, interact!, action_space, observation_space, render, Observation, get_reward, get_terminal, get_state, get_legal_actions
 
 abstract type AbstractEnv end
 
@@ -7,4 +7,26 @@ function reset! end
 function interact! end
 function action_space end
 function observation_space end
-function render end
\ No newline at end of file
+function render end
+
+struct Observation{R, T, S, M<:NamedTuple}
+    reward::R
+    terminal::T
+    state::S
+    meta::M
+end
+
+Observation(;reward, terminal, state, kw...) = Observation(reward, terminal, state, merge(NamedTuple(), kw))
+
+get_reward(obs::Observation) = obs.reward
+get_terminal(obs::Observation) = obs.terminal
+get_state(obs::Observation) = obs.state
+get_legal_actions(obs::Observation) = obs.meta.legal_actions
+
+# !!! >= julia v1.3
+if VERSION >= v"1.3.0-rc1.0"
+    (env::AbstractEnv)(a) = interact!(env, a)
+end
+
+action_space(env::AbstractEnv) = env.action_space
+observation_space(env::AbstractEnv) = env.observation_space
\ No newline at end of file
diff --git a/src/environments/atari.jl b/src/environments/atari.jl
index 6ea8859..69f1703 100644
--- a/src/environments/atari.jl
+++ b/src/environments/atari.jl
@@ -2,19 +2,17 @@ using ArcadeLearningEnvironment, GR
 
 export AtariEnv
 
-struct AtariEnv{To,F} <: AbstractEnv
+mutable struct AtariEnv{To,F} <: AbstractEnv
     ale::Ptr{Nothing}
     screen::Array{UInt8, 1}
     getscreen!::F
-    actions::Array{Int32, 1}
+    actions::Array{Int64, 1}
     action_space::DiscreteSpace{Int}
     observation_space::To
     noopmax::Int
+    reward::Float32
 end
 
-action_space(env::AtariEnv) = env.action_space
-observation_space(env::AtariEnv) = env.observation_space
-
 """
     AtariEnv(name; colorspace = "Grayscale", frame_skip = 4, noopmax = 20,
              color_averaging = true, repeat_action_probability = 0.)
@@ -51,24 +49,26 @@ function AtariEnv(name;
     end
     actions = actionset == :minimal ? getMinimalActionSet(ale) : getLegalActionSet(ale)
     action_space = DiscreteSpace(length(actions))
-    AtariEnv(ale, screen, getscreen!, actions, action_space, observation_space, noopmax)
+    AtariEnv(ale, screen, getscreen!, actions, action_space, observation_space, noopmax, 0.0f0)
 end
 
 function interact!(env::AtariEnv, a)
-    r = act(env.ale, env.actions[a])
+    env.reward = act(env.ale, env.actions[a])
     env.getscreen!(env.ale, env.screen)
-    (observation=env.screen, reward=r, isdone=game_over(env.ale))
+    nothing
 end
-function observe(env::AtariEnv)
-    env.getscreen!(env.ale, env.screen)
-    (observation=env.screen, isdone=game_over(env.ale))
-end
+observe(env::AtariEnv) = Observation(
+    reward = env.reward,
+    terminal = game_over(env.ale),
+    state = env.screen
+)
 
 function reset!(env::AtariEnv)
     reset_game(env.ale)
     for _ in 1:rand(0:env.noopmax)
         act(env.ale, Int32(0))
     end
     env.getscreen!(env.ale, env.screen)
+    env.reward = 0.0f0 # dummy
     nothing
 end
diff --git a/src/environments/classic_control/cart_pole.jl b/src/environments/classic_control/cart_pole.jl
index e2e9079..6f8ebf8 100644
--- a/src/environments/classic_control/cart_pole.jl
+++ b/src/environments/classic_control/cart_pole.jl
@@ -42,9 +42,6 @@ function CartPoleEnv(; T = Float64, gravity = T(9.8), masscart = T(1.),
     cp
 end
 
-action_space(env::CartPoleEnv) = env.action_space
-observation_space(env::CartPoleEnv) = env.observation_space
-
 function reset!(env::CartPoleEnv{T}) where T <: Number
     env.state[:] = T(.1) * rand(env.rng, T, 4) .- T(.05)
     env.t = 0
@@ -53,7 +50,11 @@ function reset!(env::CartPoleEnv{T}) where T <: Number
     nothing
 end
 
-observe(env::CartPoleEnv) = (observation=env.state, isdone=env.done)
+observe(env::CartPoleEnv) = Observation(
+    reward = env.done ? 0.0 : 1.0,
+    terminal = env.done,
+    state = env.state
+)
 
 function interact!(env::CartPoleEnv{T}, a) where T <: Number
     env.action = a
@@ -76,7 +77,7 @@ function interact!(env::CartPoleEnv{T}, a) where T <: Number
     env.done = abs(env.state[1]) > env.params.xthreshold ||
                abs(env.state[3]) > env.params.thetathreshold ||
                env.t >= env.params.max_steps
-    (observation=env.state, reward=1., isdone=env.done)
+    nothing
 end
 
 function plotendofepisode(x, y, d)
diff --git a/src/environments/classic_control/mdp.jl b/src/environments/classic_control/mdp.jl
index 2ff10e4..3c9ef5a 100644
--- a/src/environments/classic_control/mdp.jl
+++ b/src/environments/classic_control/mdp.jl
@@ -9,46 +9,52 @@ export MDPEnv, POMDPEnv, SimpleMDPEnv, absorbing_deterministic_tree_MDP, stochas
 ##### POMDPEnv
 #####
 
-mutable struct POMDPEnv{T,Ts,Ta, R<:AbstractRNG}
+mutable struct POMDPEnv{T,Ts,Ta, R<:AbstractRNG} <: AbstractEnv
     model::T
     state::Ts
     actions::Ta
     action_space::DiscreteSpace
     observation_space::DiscreteSpace
+    observation::Int
+    reward::Float64
     rng::R
 end
 
-POMDPEnv(model; rng=Random.GLOBAL_RNG) = POMDPEnv(
-    model,
-    initialstate(model, rng),
-    actions(model),
-    DiscreteSpace(n_actions(model)),
-    DiscreteSpace(n_states(model)),
-    rng)
+function POMDPEnv(model; rng=Random.GLOBAL_RNG)
+    state = initialstate(model, rng)
+    as = DiscreteSpace(n_actions(model))
+    os = DiscreteSpace(n_states(model))
+    actions_of_model = actions(model)
+    s, o, r = generate_sor(model, state, actions_of_model[rand(as)], rng)
+    obs = observationindex(model, o)
+    POMDPEnv(model, state, actions_of_model, as, os, obs, 0., rng)
+end
 
 function interact!(env::POMDPEnv, action)
     s, o, r = generate_sor(env.model, env.state, env.actions[action], env.rng)
     env.state = s
-    (observation = observationindex(env.model, o),
-     reward = r,
-     isdone = isterminal(env.model, s))
+    env.reward = r
+    env.observation = observationindex(env.model, o)
+    nothing
 end
 
-function observe(env::POMDPEnv)
-    (observation = observationindex(env.model, generate_o(env.model, env.state, env.rng)),
-     isdone = isterminal(env.model, env.state))
-end
+observe(env::POMDPEnv) = Observation(
+    reward = env.reward,
+    terminal = isterminal(env.model, env.state),
+    state = env.observation
+)
 
 #####
 ##### MDPEnv
 #####
 
-mutable struct MDPEnv{T, Ts, Ta, R<:AbstractRNG}
+mutable struct MDPEnv{T, Ts, Ta, R<:AbstractRNG} <: AbstractEnv
     model::T
     state::Ts
     actions::Ta
     action_space::DiscreteSpace
     observation_space::DiscreteSpace
+    reward::Float64
     rng::R
 end
 
@@ -58,10 +64,9 @@ MDPEnv(model; rng=Random.GLOBAL_RNG) = MDPEnv(
     actions(model),
     DiscreteSpace(n_actions(model)),
     DiscreteSpace(n_states(model)),
-    rng)
-
-action_space(env::Union{MDPEnv, POMDPEnv}) = env.action_space
-observation_space(env::Union{MDPEnv, POMDPEnv}) = env.observation_space
+    0.,
+    rng
+)
 
 observationindex(env, o) = Int(o) + 1
 
@@ -74,15 +79,15 @@ function interact!(env::MDPEnv, action)
     s = rand(env.rng, transition(env.model, env.state, env.actions[action]))
     r = POMDPs.reward(env.model, env.state, env.actions[action])
     env.state = s
-    (observation = stateindex(env.model, s),
-     reward = r,
-     isdone = isterminal(env.model, s))
+    env.reward = r
+    nothing
 end
 
-function observe(env::MDPEnv)
-    (observation = stateindex(env.model, env.state),
-     isdone = isterminal(env.model, env.state))
-end
+observe(env::MDPEnv) = Observation(
+    reward = env.reward,
+    terminal = isterminal(env.model, env.state),
+    state = stateindex(env.model, env.state)
+)
 #####
 ##### SimpleMDPEnv
 #####
@@ -107,7 +112,7 @@ probabilities) `reward` of type `R` (see [`DeterministicStateActionReward`](@ref
 [`NormalStateActionReward`](@ref)), array of initial states `initialstates`, and
 `ns` - array of 0/1 indicating if a state is terminal.
 """
-mutable struct SimpleMDPEnv{T,R,S<:AbstractRNG}
+mutable struct SimpleMDPEnv{T,R,S<:AbstractRNG} <: AbstractEnv
     observation_space::DiscreteSpace
     action_space::DiscreteSpace
     state::Int
@@ -115,6 +120,7 @@ mutable struct SimpleMDPEnv{T,R,S<:AbstractRNG}
     reward::R
     initialstates::Array{Int, 1}
     isterminal::Array{Int, 1}
+    score::Float64
     rng::S
 end
 
@@ -125,12 +131,9 @@ function SimpleMDPEnv(ospace, aspace, state, trans_probs::Array{T, 2},
         reward = DeterministicStateActionReward(reward)
     end
     SimpleMDPEnv{T,typeof(reward),S}(ospace, aspace, state, trans_probs,
-                                     reward, initialstates, isterminal, rng)
+                                     reward, initialstates, isterminal, 0., rng)
 end
 
-observation_space(env::SimpleMDPEnv) = env.observation_space
-action_space(env::SimpleMDPEnv) = env.action_space
-
 # reward types
 """
     struct DeterministicNextStateReward
@@ -208,13 +211,15 @@ run!(mdp::SimpleMDPEnv, policy::Array{Int, 1}) = run!(mdp, policy[mdp.state])
 function interact!(env::SimpleMDPEnv, action)
     oldstate = env.state
     run!(env, action)
-    r = reward(env.rng, env.reward, oldstate, action, env.state)
-    (observation = env.state, reward = r, isdone = env.isterminal[env.state] == 1)
+    env.score = reward(env.rng, env.reward, oldstate, action, env.state)
+    nothing
 end
 
-function observe(env::SimpleMDPEnv)
-    (observation = env.state, isdone = env.isterminal[env.state] == 1)
-end
+observe(env::SimpleMDPEnv) = Observation(
+    reward = env.score,
+    terminal = env.isterminal[env.state] == 1,
+    state = env.state
+)
 
 function reset!(env::SimpleMDPEnv)
     env.state = rand(env.rng, env.initialstates)
diff --git a/src/environments/classic_control/mountain_car.jl b/src/environments/classic_control/mountain_car.jl
index ad11ae6..8538ec2 100644
--- a/src/environments/classic_control/mountain_car.jl
+++ b/src/environments/classic_control/mountain_car.jl
@@ -50,11 +50,14 @@ function MountainCarEnv(; T = Float64, continuous = false,
     reset!(env)
     env
 end
+
 ContinuousMountainCarEnv(; kwargs...) = MountainCarEnv(; continuous = true, kwargs...)
 
-action_space(env::MountainCarEnv) = env.action_space
-observation_space(env::MountainCarEnv) = env.observation_space
-observe(env::MountainCarEnv) = (observation=env.state, isdone=env.done)
+observe(env::MountainCarEnv) = Observation(
+    reward = env.done ? 0. : -1.,
+    terminal = env.done,
+    state = env.state
+)
 
 function reset!(env::MountainCarEnv{A, T}) where {A, T}
     env.state[1] = .2 * rand(env.rng, T) - .6
@@ -78,7 +81,7 @@ function _interact!(env::MountainCarEnv, force)
                env.t >= env.params.max_steps
     env.state[1] = x
     env.state[2] = v
-    (observation=env.state, reward=-1., isdone=env.done)
+    nothing
 end
 
 # adapted from https://github.com/JuliaML/Reinforce.jl/blob/master/src/envs/mountain_car.jl
diff --git a/src/environments/classic_control/pendulum.jl b/src/environments/classic_control/pendulum.jl
index 11bcb04..0f11d3a 100644
--- a/src/environments/classic_control/pendulum.jl
+++ b/src/environments/classic_control/pendulum.jl
@@ -20,31 +20,42 @@ mutable struct PendulumEnv{T, R<:AbstractRNG} <: AbstractEnv
     done::Bool
     t::Int
     rng::R
+    reward::T
 end
 
 function PendulumEnv(; T = Float64, max_speed = T(8), max_torque = T(2), g = T(10),
                      m = T(1), l = T(1), dt = T(.05), max_steps = 200)
     high = T.([1, 1, max_speed])
-    env = PendulumEnv(PendulumEnvParams(max_speed, max_torque, g, m, l, dt, max_steps),
-                      ContinuousSpace(-2., 2.),
-                      MultiContinuousSpace(-high, high),
-                      zeros(T, 2), false, 0, Random.GLOBAL_RNG)
+    env = PendulumEnv(
+        PendulumEnvParams(max_speed, max_torque, g, m, l, dt, max_steps),
+        ContinuousSpace(-2., 2.),
+        MultiContinuousSpace(-high, high),
+        zeros(T, 2),
+        false,
+        0,
+        Random.GLOBAL_RNG,
+        zero(T)
+    )
     reset!(env)
     env
 end
 
-action_space(env::PendulumEnv) = env.action_space
-observation_space(env::PendulumEnv) = env.observation_space
-
 pendulum_observation(s) = [cos(s[1]), sin(s[1]), s[2]]
 angle_normalize(x) = ((x + pi) % (2*pi)) - pi
 
-observe(env::PendulumEnv) = (observation=pendulum_observation(env.state), isdone=env.done)
+function observe(env::PendulumEnv)
+    Observation(
+        reward = env.reward,
+        state = pendulum_observation(env.state),
+        terminal = env.done
+    )
+end
 
 function reset!(env::PendulumEnv{T}) where T
     env.state[:] = 2 * rand(env.rng, T, 2) .- 1
     env.t = 0
     env.done = false
+    env.reward = zero(T)
     nothing
 end
 
@@ -60,5 +71,6 @@ function interact!(env::PendulumEnv, a)
     env.state[1] = th
     env.state[2] = newthdot
     env.done = env.t >= env.params.max_steps
-    (observation=pendulum_observation(env.state), reward=-costs, isdone=env.done)
+    env.reward = -costs
+    nothing
 end
\ No newline at end of file
diff --git a/src/environments/gym.jl b/src/environments/gym.jl
index 7dfbb95..6c00670 100644
--- a/src/environments/gym.jl
+++ b/src/environments/gym.jl
@@ -34,13 +34,9 @@ function GymEnv(name::String)
     env
 end
 
-action_space(env::GymEnv) = env.action_space
-observation_space(env::GymEnv) = env.observation_space
-
 function interact!(env::GymEnv{T}, action) where T
     pycall!(env.state, env.pyenv.step, PyObject, action)
-    obs, reward, isdone, info = convert(Tuple{T, Float64, Bool, PyDict}, env.state)
-    (observation=obs, reward=reward, isdone=isdone)
+    nothing
 end
 
 function reset!(env::GymEnv)
@@ -51,10 +47,18 @@ end
 function observe(env::GymEnv{T}) where T
     if pyisinstance(env.state, PyCall.@pyglobalobj :PyTuple_Type)
         obs, reward, isdone, info = convert(Tuple{T, Float64, Bool, PyDict}, env.state)
-        (observation=obs, isdone=isdone)
+        Observation(
+            reward = reward,
+            terminal = isdone,
+            state = obs
+        )
     else # env has just been reseted
-        (observation=convert(T, env.state), isdone=false)
+        Observation(
+            reward = 0., # dummy
+            terminal = false,
+            state=convert(T, env.state)
+        )
     end
 end
 
diff --git a/src/environments/hanabi.jl b/src/environments/hanabi.jl
index 56ea767..7a59156 100644
--- a/src/environments/hanabi.jl
+++ b/src/environments/hanabi.jl
@@ -200,10 +200,12 @@ function observe(env::HanabiEnv, observer=state_cur_player(env.state))
     observation_finalizer(raw_obs)
     new_observation(env.state, observer, raw_obs)
 
-    (observation = raw_obs,
-     reward = env.reward.player == observer ? env.reward.score_gain : Int32(0),
-     isdone = state_end_of_game_status(env.state) != Int(NOT_FINISHED),
-     game = env.game)
+    Observation(
+        reward = env.reward.player == observer ? env.reward.score_gain : Int32(0),
+        terminal = state_end_of_game_status(env.state) != Int(NOT_FINISHED),
+        state = raw_obs,
+        game = env.game
+    )
 end
 
 function encode_observation(obs, env)
diff --git a/src/spaces/multi_continuous_space.jl b/src/spaces/multi_continuous_space.jl
index 583bd30..94c9654 100644
--- a/src/spaces/multi_continuous_space.jl
+++ b/src/spaces/multi_continuous_space.jl
@@ -17,4 +17,6 @@ MultiContinuousSpace(low, high) = MultiContinuousSpace(convert(Array{Float64}, l
 Base.eltype(::MultiContinuousSpace{S, N}) where {S, N} = Array{Float64, N}
 Base.in(xs, s::MultiContinuousSpace{S, N}) where {S, N} = size(xs) == S && all(l <= x <= h for (l, x, h) in zip(s.low, xs, s.high))
 Base.:(==)(s1::MultiContinuousSpace, s2::MultiContinuousSpace) = s1.low == s2.low && s1.high == s2.high
-Base.rand(rng::AbstractRNG, s::MultiContinuousSpace) = map((l, h) -> rand(rng, Uniform(l, h)), s.low, s.high)
\ No newline at end of file
+Base.rand(rng::AbstractRNG, s::MultiContinuousSpace) = map((l, h) -> rand(rng, Uniform(l, h)), s.low, s.high)
+Base.size(s::MultiContinuousSpace) = size(s.low)
+Base.length(s::MultiContinuousSpace) = length(s.low)
\ No newline at end of file
diff --git a/test/environments.jl b/test/environments.jl
index 259f1b4..5bfa352 100644
--- a/test/environments.jl
+++ b/test/environments.jl
@@ -10,9 +10,10 @@
     for _ in 1:n
         a = rand(as)
         @test a in as
-        obs, reward, isdone = interact!(env, a)
-        @test obs in os
-        if isdone
+        @test interact!(env, a) === nothing
+        obs = observe(env)
+        @test get_state(obs) in os
+        if get_terminal(obs)
             reset!(env)
         end
     end
@@ -24,8 +25,8 @@
     for _ in 1:n
         a = rand(legal_actions(env))
         interact!(env, a)
-        obs, reward, isdone = observe(env)
-        if isdone
+        obs = observe(env)
+        if get_terminal(obs)
             reset!(env)
         end
     end
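
For quick reference, here is a minimal sketch of how the reworked observation API is consumed. It is not part of the patch itself; it only uses names the diff above introduces (`observe`, `Observation`, the `get_*` accessors, the `interact!`-returns-`nothing` convention, and the Julia >= 1.3 callable-environment shortcut), with `CartPoleEnv` chosen purely as a convenient built-in example:

```julia
using ReinforcementLearningEnvironments

env = CartPoleEnv()
reset!(env)

# interact! now mutates the environment and returns `nothing`;
# results are queried afterwards through observe.
a = rand(action_space(env))
interact!(env, a)

obs = observe(env)                  # an Observation, no longer a NamedTuple
get_state(obs)                      # current state
get_reward(obs)                     # reward of the last transition
get_terminal(obs) && reset!(env)    # restart once the episode ends

# On Julia >= v"1.3.0-rc1.0" the environment itself is callable:
# `env(a)` is defined as `interact!(env, a)`.
```

Any extra keyword arguments passed to `Observation` (for example the `game` field in `HanabiEnv`, or `legal_actions` read back via `get_legal_actions`) are collected into its `meta` NamedTuple.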
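
The same pattern applies when wrapping a new environment. The sketch below defines a hypothetical `ToyEnv` (its name, fields, and dynamics are illustrative only, not part of the package) to show the three methods a third-party environment implements under the unified interface; because it stores `action_space` and `observation_space` fields, the new `AbstractEnv` fallbacks for `action_space(env)` and `observation_space(env)` apply without extra definitions:

```julia
using ReinforcementLearningEnvironments
import ReinforcementLearningEnvironments: interact!, observe, reset!

# A toy 3-state chain: action 1 earns a reward of 1.0, the episode ends in state 3.
mutable struct ToyEnv <: AbstractEnv
    action_space::DiscreteSpace{Int}
    observation_space::DiscreteSpace{Int}
    state::Int
    reward::Float64
end

ToyEnv() = ToyEnv(DiscreteSpace(2), DiscreteSpace(3), 1, 0.0)

function interact!(env::ToyEnv, a)
    env.reward = a == 1 ? 1.0 : 0.0   # remember the reward for the next observe call
    env.state = min(env.state + 1, 3)
    nothing                           # unified interface: interact! returns nothing
end

observe(env::ToyEnv) = Observation(
    reward = env.reward,
    terminal = env.state == 3,
    state = env.state
)

function reset!(env::ToyEnv)
    env.state = 1
    env.reward = 0.0
    nothing
end
```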