From 461dae1766d2289f389730fbd4a41e0846d7f54c Mon Sep 17 00:00:00 2001
From: Johanni Brea
Date: Fri, 7 Sep 2018 15:29:28 +0200
Subject: [PATCH 1/4] fix examples

---
 Project.toml                               |  2 +-
 examples/cartpole.jl                       |  2 +-
 examples/cartpoleDQN.jl                    |  6 +++---
 src/ReinforcementLearningEnvironmentGym.jl | 14 ++++++++++----
 test/runtests.jl                           | 10 ++--------
 5 files changed, 17 insertions(+), 17 deletions(-)

diff --git a/Project.toml b/Project.toml
index 05b8b87..558a41a 100644
--- a/Project.toml
+++ b/Project.toml
@@ -5,7 +5,7 @@ version = "0.1.0"
 
 [deps]
 PyCall = "438e738f-606a-5dbb-bf0a-cddfbfd45ab0"
-ReinforcementLearningBase = "e575027e-6cd6-5018-9292-cdc6200d2b44"
+ReinforcementLearningBase = "9b2b9cba-ac73-11e8-02b1-9f0869453fc0"
 
 [extras]
 Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

diff --git a/examples/cartpole.jl b/examples/cartpole.jl
index f9bfa70..4d30b35 100644
--- a/examples/cartpole.jl
+++ b/examples/cartpole.jl
@@ -1,5 +1,5 @@
 # Using PyCall is rather slow. Please compare to https://github.com/JuliaReinforcementLearning/ReinforcementLearningEnvironmentClassicControl.jl/blob/master/examples/cartpole.jl
-using ReinforcementLearningEnvironmentGym
+using ReinforcementLearningEnvironmentGym, ReinforcementLearning
 env = GymEnv("CartPole-v0")
 rlsetup = RLSetup(ActorCriticPolicyGradient(ns = 4, na = 2, α = .02,

diff --git a/examples/cartpoleDQN.jl b/examples/cartpoleDQN.jl
index 8e00928..1cec808 100644
--- a/examples/cartpoleDQN.jl
+++ b/examples/cartpoleDQN.jl
@@ -1,4 +1,4 @@
-using ReinforcementLearningEnvironmentGym, Flux
+using ReinforcementLearningEnvironmentGym, Flux, ReinforcementLearning
 
 # List all envs
 listallenvs()
@@ -14,12 +14,12 @@ learner = DQN(Chain(Dense(4, 48, relu), Dense(48, 24, relu), Dense(24, 2)),
 x = RLSetup(learner, env, ConstantNumberEpisodes(10),
             callbacks = [Progress(), EvaluationPerEpisode(TimeSteps()),
                          Visualize(wait = 0)])
-info("Before learning.")
+@info("Before learning.")
 run!(x)
 pop!(x.callbacks)
 x.stoppingcriterion = ConstantNumberEpisodes(400)
 @time learn!(x)
 x.stoppingcriterion = ConstantNumberEpisodes(10)
 push!(x.callbacks, Visualize(wait = 0))
-info("After learning.")
+@info("After learning.")
 run!(x)

diff --git a/src/ReinforcementLearningEnvironmentGym.jl b/src/ReinforcementLearningEnvironmentGym.jl
index 4165c8f..aac26fa 100644
--- a/src/ReinforcementLearningEnvironmentGym.jl
+++ b/src/ReinforcementLearningEnvironmentGym.jl
@@ -1,5 +1,4 @@
 module ReinforcementLearningEnvironmentGym
-export GymEnv, listallenvs, interact!, reset!, getstate, plotenv, actionspace, sample
 using ReinforcementLearningBase
 import ReinforcementLearningBase:interact!, reset!, getstate, plotenv, actionspace
 using PyCall
@@ -34,13 +33,19 @@ function GymEnv(name::String)
     obsspace = gymspace2jlspace(pyenv[:observation_space])
     actspace = gymspace2jlspace(pyenv[:action_space])
     state = PyNULL()
-    GymEnv(pyenv, obsspace, actspace, state)
+    env = GymEnv(pyenv, obsspace, actspace, state)
+    reset!(env) # state needs to be set to call defaultbuffer in RL
+    env
 end
 
 function interact!(env::GymEnv, action)
     pycall!(env.state, env.pyobj[:step], PyVector, action)
     (observation=env.state[1], reward=env.state[2], isdone=env.state[3])
 end
+function interact!(env::GymEnv, action::Int)
+    pycall!(env.state, env.pyobj[:step], PyVector, action - 1)
+    (observation=env.state[1], reward=env.state[2], isdone=env.state[3])
+end
 
 "Not very useful, kept for compat"
 function getstate(env::GymEnv)
@@ -48,11 +53,11 @@ function getstate(env::GymEnv)
         (observation=env.state[1], isdone=env.state[3])
     else # env has just been reseted
-        (observation=env.state, isdone=false)
+        (observation=Float64.(env.state), isdone=false)
     end
 end
 
-reset!(env::GymEnv) = (observation=pycall!(env.state, env.pyobj[:reset], PyArray),)
+reset!(env::GymEnv) = (observation=Float64.(pycall!(env.state, env.pyobj[:reset], PyArray)),)
 
 plotenv(env::GymEnv) = env.pyobj[:render]()
 actionspace(env::GymEnv) = env.actionspace
@@ -71,4 +76,5 @@ function listallenvs(pattern = r"")
     end
 end
 
+export GymEnv, listallenvs
 end # module

diff --git a/test/runtests.jl b/test/runtests.jl
index 51098e2..7795a11 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -1,9 +1,3 @@
-using ReinforcementLearningEnvironmentGym
-using PyCall
-using Test
+using ReinforcementLearningEnvironmentGym, Test, ReinforcementLearningBase
 
-for x in ["CartPole-v0"]
-    env = GymEnv(x)
-    @test typeof(reset!(env)) == NamedTuple{(:observation,), Tuple{PyArray{Float64, 1}}}
-    @test typeof(interact!(env, 1)) == NamedTuple{(:observation, :reward, :isdone), Tuple{Array{Float64, 1}, Float64, Bool}}
-end
+test_envinterface(GymEnv("CartPole-v0"))

From 34cd3e87b4e1440bc4decc228e2b71206f095b92 Mon Sep 17 00:00:00 2001
From: Jun Tian
Date: Sat, 8 Sep 2018 14:39:44 +0800
Subject: [PATCH 2/4] no need to modify the deps id by hand

---
 Project.toml | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/Project.toml b/Project.toml
index 558a41a..05b8b87 100644
--- a/Project.toml
+++ b/Project.toml
@@ -5,7 +5,7 @@ version = "0.1.0"
 
 [deps]
 PyCall = "438e738f-606a-5dbb-bf0a-cddfbfd45ab0"
-ReinforcementLearningBase = "9b2b9cba-ac73-11e8-02b1-9f0869453fc0"
+ReinforcementLearningBase = "e575027e-6cd6-5018-9292-cdc6200d2b44"
 
 [extras]
 Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

From 99a225dd4b82284d656f66892a650042a641b2fe Mon Sep 17 00:00:00 2001
From: Jun Tian
Date: Sat, 8 Sep 2018 16:43:49 +0800
Subject: [PATCH 3/4] fix index issue

---
 src/ReinforcementLearningEnvironmentGym.jl | 18 ++++++++++++------
 1 file changed, 12 insertions(+), 6 deletions(-)

diff --git a/src/ReinforcementLearningEnvironmentGym.jl b/src/ReinforcementLearningEnvironmentGym.jl
index aac26fa..1a55ab0 100644
--- a/src/ReinforcementLearningEnvironmentGym.jl
+++ b/src/ReinforcementLearningEnvironmentGym.jl
@@ -12,19 +12,19 @@ end
 function gymspace2jlspace(s::PyObject)
     spacetype = s[:__class__][:__name__]
     if spacetype == "Box" BoxSpace(s[:low], s[:high])
-    elseif spacetype == "Discrete" DiscreteSpace(s[:n], 0)
+    elseif spacetype == "Discrete" DiscreteSpace(s[:n], 1)
     elseif spacetype == "MultiBinary" MultiBinarySpace(s[:n])
-    elseif spacetype == "MultiDiscrete" MultiDiscreteSpace(s[:nvec], 0)
+    elseif spacetype == "MultiDiscrete" MultiDiscreteSpace(s[:nvec], 1)
     elseif spacetype == "Tuple" map(gymspace2jlspace, s[:spaces])
     elseif spacetype == "Dict" Dict(map((k, v) -> (k, gymspace2jlspace(v)), s[:spaces]))
     else error("Don't know how to convert [$(spacetype)]")
     end
 end
 
-struct GymEnv <: AbstractEnv
+struct GymEnv{Ta<:AbstractSpace, To<:AbstractSpace} <: AbstractEnv
     pyobj::PyObject
-    observationspace::AbstractSpace
-    actionspace::AbstractSpace
+    observationspace::To
+    actionspace::Ta
     state::PyObject
 end
 
@@ -42,11 +42,17 @@
 function interact!(env::GymEnv, action)
     pycall!(env.state, env.pyobj[:step], PyVector, action)
     (observation=env.state[1], reward=env.state[2], isdone=env.state[3])
 end
-function interact!(env::GymEnv, action::Int)
+
+function interact!(env::GymEnv{BoxSpace}, action::Int)
     pycall!(env.state, env.pyobj[:step], PyVector, action - 1)
     (observation=env.state[1], reward=env.state[2], isdone=env.state[3])
 end
+function interact!(env::GymEnv{MultiDiscreteSpace}, action::AbstractArray{Int})
+    pycall!(env.state, env.pyobj[:step], PyVector, action .- 1)
+    (observation=env.state[1], reward=env.state[2], isdone=env.state[3])
+end
+
 
 "Not very useful, kept for compat"
 function getstate(env::GymEnv)

From 3db25181b552eebf52c9f28e4285e0feba196b4a Mon Sep 17 00:00:00 2001
From: Jun Tian
Date: Sat, 8 Sep 2018 17:07:36 +0800
Subject: [PATCH 4/4] minor fix

---
 src/ReinforcementLearningEnvironmentGym.jl | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/ReinforcementLearningEnvironmentGym.jl b/src/ReinforcementLearningEnvironmentGym.jl
index 1a55ab0..789ef97 100644
--- a/src/ReinforcementLearningEnvironmentGym.jl
+++ b/src/ReinforcementLearningEnvironmentGym.jl
@@ -43,7 +43,7 @@ function interact!(env::GymEnv, action)
     (observation=env.state[1], reward=env.state[2], isdone=env.state[3])
 end
 
-function interact!(env::GymEnv{BoxSpace}, action::Int)
+function interact!(env::GymEnv{DiscreteSpace}, action::Int)
     pycall!(env.state, env.pyobj[:step], PyVector, action - 1)
     (observation=env.state[1], reward=env.state[2], isdone=env.state[3])
 end
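
Usage note (not part of the patches): after this series, discrete action spaces are
exposed 1-based on the Julia side (DiscreteSpace(s[:n], 1)) and interact! translates
back to Gym's 0-based actions by subtracting 1. A minimal sketch of the resulting
usage, assuming all four patches are applied; the environment id is taken from the
examples above, and actionspace/getstate come into scope via ReinforcementLearningBase,
as in the updated tests:

    using ReinforcementLearningEnvironmentGym, ReinforcementLearningBase

    env = GymEnv("CartPole-v0")   # the constructor now calls reset!, so env.state is initialized
    actionspace(env)              # 1-based DiscreteSpace (actions 1 and 2 for CartPole)
    getstate(env)                 # (observation = ..., isdone = false) right after reset
    step = interact!(env, 1)      # Julia action 1 is sent to Gym as action 0
    step.reward, step.isdone      # Float64 reward and Bool termination flag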