diff --git a/examples/cartpole.jl b/examples/cartpole.jl
index f9bfa70..4d30b35 100644
--- a/examples/cartpole.jl
+++ b/examples/cartpole.jl
@@ -1,5 +1,5 @@
 # Using PyCall is rather slow. Please compare to https://github.com/JuliaReinforcementLearning/ReinforcementLearningEnvironmentClassicControl.jl/blob/master/examples/cartpole.jl
-using ReinforcementLearningEnvironmentGym
+using ReinforcementLearningEnvironmentGym, ReinforcementLearning
 env = GymEnv("CartPole-v0")
 rlsetup = RLSetup(ActorCriticPolicyGradient(ns = 4, na = 2, α = .02,
diff --git a/examples/cartpoleDQN.jl b/examples/cartpoleDQN.jl
index 8e00928..1cec808 100644
--- a/examples/cartpoleDQN.jl
+++ b/examples/cartpoleDQN.jl
@@ -1,4 +1,4 @@
-using ReinforcementLearningEnvironmentGym, Flux
+using ReinforcementLearningEnvironmentGym, Flux, ReinforcementLearning
 
 # List all envs
 listallenvs()
@@ -14,12 +14,12 @@ learner = DQN(Chain(Dense(4, 48, relu), Dense(48, 24, relu), Dense(24, 2)),
 x = RLSetup(learner, env, ConstantNumberEpisodes(10),
             callbacks = [Progress(), EvaluationPerEpisode(TimeSteps()),
                          Visualize(wait = 0)])
-info("Before learning.")
+@info("Before learning.")
 run!(x)
 pop!(x.callbacks)
 x.stoppingcriterion = ConstantNumberEpisodes(400)
 @time learn!(x)
 x.stoppingcriterion = ConstantNumberEpisodes(10)
 push!(x.callbacks, Visualize(wait = 0))
-info("After learning.")
+@info("After learning.")
 run!(x)
diff --git a/src/ReinforcementLearningEnvironmentGym.jl b/src/ReinforcementLearningEnvironmentGym.jl
index 4165c8f..789ef97 100644
--- a/src/ReinforcementLearningEnvironmentGym.jl
+++ b/src/ReinforcementLearningEnvironmentGym.jl
@@ -1,5 +1,4 @@
 module ReinforcementLearningEnvironmentGym
-export GymEnv, listallenvs, interact!, reset!, getstate, plotenv, actionspace, sample
 using ReinforcementLearningBase
 import ReinforcementLearningBase:interact!, reset!, getstate, plotenv, actionspace
 using PyCall
@@ -13,19 +12,19 @@ end
 function gymspace2jlspace(s::PyObject)
     spacetype = s[:__class__][:__name__]
     if spacetype == "Box" BoxSpace(s[:low], s[:high])
-    elseif spacetype == "Discrete" DiscreteSpace(s[:n], 0)
+    elseif spacetype == "Discrete" DiscreteSpace(s[:n], 1)
     elseif spacetype == "MultiBinary" MultiBinarySpace(s[:n])
-    elseif spacetype == "MultiDiscrete" MultiDiscreteSpace(s[:nvec], 0)
+    elseif spacetype == "MultiDiscrete" MultiDiscreteSpace(s[:nvec], 1)
     elseif spacetype == "Tuple" map(gymspace2jlspace, s[:spaces])
     elseif spacetype == "Dict" Dict(map((k, v) -> (k, gymspace2jlspace(v)), s[:spaces]))
     else error("Don't know how to convert [$(spacetype)]")
     end
 end
 
-struct GymEnv <: AbstractEnv
+struct GymEnv{Ta<:AbstractSpace, To<:AbstractSpace} <: AbstractEnv
     pyobj::PyObject
-    observationspace::AbstractSpace
-    actionspace::AbstractSpace
+    observationspace::To
+    actionspace::Ta
     state::PyObject
 end
 
@@ -34,7 +33,9 @@ function GymEnv(name::String)
     obsspace = gymspace2jlspace(pyenv[:observation_space])
     actspace = gymspace2jlspace(pyenv[:action_space])
     state = PyNULL()
-    GymEnv(pyenv, obsspace, actspace, state)
+    env = GymEnv(pyenv, obsspace, actspace, state)
+    reset!(env) # state needs to be set to call defaultbuffer in RL
+    env
 end
 
 function interact!(env::GymEnv, action)
@@ -42,17 +43,27 @@ function interact!(env::GymEnv, action)
     (observation=env.state[1], reward=env.state[2], isdone=env.state[3])
 end
 
+function interact!(env::GymEnv{DiscreteSpace}, action::Int)
+    pycall!(env.state, env.pyobj[:step], PyVector, action - 1)
+    (observation=env.state[1], reward=env.state[2], isdone=env.state[3])
+end
+
+function interact!(env::GymEnv{MultiDiscreteSpace}, action::AbstractArray{Int})
+    pycall!(env.state, env.pyobj[:step], PyVector, action .- 1)
+    (observation=env.state[1], reward=env.state[2], isdone=env.state[3])
+end
+
 "Not very useful, kept for compat"
 function getstate(env::GymEnv)
     if pyisinstance(env.state, PyCall.@pyglobalobj :PyTuple_Type)
         (observation=env.state[1], isdone=env.state[3])
     else # env has just been reseted
-        (observation=env.state, isdone=false)
+        (observation=Float64.(env.state), isdone=false)
     end
 end
 
-reset!(env::GymEnv) = (observation=pycall!(env.state, env.pyobj[:reset], PyArray),)
+reset!(env::GymEnv) = (observation=Float64.(pycall!(env.state, env.pyobj[:reset], PyArray)),)
 plotenv(env::GymEnv) = env.pyobj[:render]()
 actionspace(env::GymEnv) = env.actionspace
 
@@ -71,4 +82,5 @@ function listallenvs(pattern = r"")
         end
     end
 end
 
+export GymEnv, listallenvs
 end # module
diff --git a/test/runtests.jl b/test/runtests.jl
index 51098e2..7795a11 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -1,9 +1,3 @@
-using ReinforcementLearningEnvironmentGym
-using PyCall
-using Test
+using ReinforcementLearningEnvironmentGym, Test, ReinforcementLearningBase
 
-for x in ["CartPole-v0"]
-    env = GymEnv(x)
-    @test typeof(reset!(env)) == NamedTuple{(:observation,), Tuple{PyArray{Float64, 1}}}
-    @test typeof(interact!(env, 1)) == NamedTuple{(:observation, :reward, :isdone), Tuple{Array{Float64, 1}, Float64, Bool}}
-end
+test_envinterface(GymEnv("CartPole-v0"))
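For reviewers, a minimal usage sketch (not part of the patch) of the 1-based action convention this diff introduces: it assumes a working Python gym installation through PyCall, and the random policy plus the 200-step cap are purely illustrative.

# Usage sketch, assuming the updated package above is installed and Python gym is available.
using ReinforcementLearningEnvironmentGym

env = GymEnv("CartPole-v0")           # the constructor now calls reset! itself
getstate(env).observation             # Float64 observation vector of the reset state

for t in 1:200                        # illustrative step cap
    res = interact!(env, rand(1:2))   # CartPole actions are 1 or 2 on the Julia side;
                                      # interact! forwards action - 1 to gym's 0-based step
    res.isdone && break
end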