Skip to content

Commit

Permalink
Merge pull request #1 from findmyway/master
Browse files Browse the repository at this point in the history
Use ReinforcementLearningBase as dep
  • Loading branch information
findmyway authored Sep 6, 2018
2 parents b075d84 + f400d11 commit 631626b
Show file tree
Hide file tree
Showing 7 changed files with 59 additions and 45 deletions.
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
*.jl.cov
*.jl.*.cov
*.jl.mem

Manifest.toml
1 change: 0 additions & 1 deletion .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ os:
- linux
- osx
julia:
- 0.7
- 1.0
- nightly
notifications:
Expand Down
14 changes: 14 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
name = "ReinforcementLearningEnvironmentGym"
uuid = "1053412c-b184-11e8-2eb2-ef51b44fed15"
authors = ["Johanni Brea <[email protected]>", "Jun Tian <[email protected]>"]
version = "0.1.0"

[deps]
PyCall = "438e738f-606a-5dbb-bf0a-cddfbfd45ab0"
ReinforcementLearningBase = "e575027e-6cd6-5018-9292-cdc6200d2b44"

[extras]
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["Test"]
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,4 @@

[![codecov.io](http://codecov.io/github/JuliaReinforcementLearning/ReinforcementLearningEnvironmentGym.jl/coverage.svg?branch=master)](http://codecov.io/github/JuliaReinforcementLearning/ReinforcementLearningEnvironmentGym.jl?branch=master)

Making the [OpenAI gym](https://github.com/openai/gym) environments available to the [Julia Reinforcement Learning](https://github.com/jbrea/ReinforcementLearning.jl) package.
Making the [OpenAI gym](https://github.com/openai/gym) environments available to the [Julia Reinforcement Learning](https://github.com/JuliaReinforcementLearning/ReinforcementLearning.jl) package.
5 changes: 2 additions & 3 deletions REQUIRE
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
julia 0.7
julia 1.0
PyCall
Reexport
ReinforcementLearning
ReinforcementLearningBase
67 changes: 36 additions & 31 deletions src/ReinforcementLearningEnvironmentGym.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
module ReinforcementLearningEnvironmentGym
using Reexport
@reexport using ReinforcementLearning
import ReinforcementLearning:interact!, reset!, getstate, plotenv
export GymEnv, listallenvs, interact!, reset!, getstate, plotenv, actionspace, sample
using ReinforcementLearningBase
import ReinforcementLearningBase:interact!, reset!, getstate, plotenv, actionspace
using PyCall
const gym = PyNULL()

Expand All @@ -10,45 +10,52 @@ function __init__()
pyimport("pybullet_envs")
end

function getspace(space)
if pyisinstance(space, gym[:spaces][:box][:Box])
ReinforcementLearning.Box(space[:low], space[:high])
elseif pyisinstance(space, gym[:spaces][:discrete][:Discrete])
1:space[:n]
else
error("Don't know how to convert $(pytypeof(space)).")
function gymspace2jlspace(s::PyObject)
spacetype = s[:__class__][:__name__]
if spacetype == "Box" BoxSpace(s[:low], s[:high])
elseif spacetype == "Discrete" DiscreteSpace(s[:n], 0)
elseif spacetype == "MultiBinary" MultiBinarySpace(s[:n])
elseif spacetype == "MultiDiscrete" MultiDiscreteSpace(s[:nvec], 0)
elseif spacetype == "Tuple" map(gymspace2jlspace, s[:spaces])
elseif spacetype == "Dict" Dict(map((k, v) -> (k, gymspace2jlspace(v)), s[:spaces]))
else error("Don't know how to convert [$(spacetype)]")
end
end
struct GymEnv{TObject, TObsSpace, TActionSpace}
pyobj::TObject
observation_space::TObsSpace
action_space::TActionSpace

struct GymEnv <: AbstractEnv
pyobj::PyObject
observationspace::AbstractSpace
actionspace::AbstractSpace
state::PyObject
end

function GymEnv(name::String)
pyenv = gym[:make](name)
obsspace = getspace(pyenv[:observation_space])
actspace = getspace(pyenv[:action_space])
pyenv[:reset]()
obsspace = gymspace2jlspace(pyenv[:observation_space])
actspace = gymspace2jlspace(pyenv[:action_space])
state = PyNULL()
pycall!(state, pyenv[:step], PyVector, pyenv[:action_space][:sample]())
pyenv[:reset]()
GymEnv(pyenv, obsspace, actspace, state)
end

function interactgym!(action, env)
if env.state[3]
reset!(env)
end
function interact!(env::GymEnv, action)
pycall!(env.state, env.pyobj[:step], PyVector, action)
env.state[1], env.state[2], env.state[3]
(observation=env.state[1], reward=env.state[2], isdone=env.state[3])
end

"Not very useful, kept for compat"
function getstate(env::GymEnv)
if pyisinstance(env.state, PyCall.@pyglobalobj :PyTuple_Type)
(observation=env.state[1], isdone=env.state[3])
else
# env has just been reseted
(observation=env.state, isdone=false)
end
end
interact!(action, env::GymEnv) = interactgym!(action, env)
interact!(action::Int64, env::GymEnv) = interactgym!(action - 1, env)
reset!(env::GymEnv) = env.pyobj[:reset]()
getstate(env::GymEnv) = (env.state[1], false)

plotenv(env::GymEnv, s, a, r, d) = env.pyobj[:render]()
reset!(env::GymEnv) = (observation=pycall!(env.state, env.pyobj[:reset], PyArray),)
plotenv(env::GymEnv) = env.pyobj[:render]()
actionspace(env::GymEnv) = env.actionspace

"""
listallenvs(pattern = r"")
Expand All @@ -64,6 +71,4 @@ function listallenvs(pattern = r"")
end
end

export GymEnv, listallenvs

end # module
13 changes: 4 additions & 9 deletions test/runtests.jl
Original file line number Diff line number Diff line change
@@ -1,14 +1,9 @@
using ReinforcementLearningEnvironmentGym
@static if VERSION < v"0.7.0-DEV.2005"
using Base.Test
else
using Test
end
using PyCall
using Test

import ReinforcementLearningEnvironmentGym: reset!, interact!, getstate
for x in ["CartPole-v0"]
env = GymEnv(x)
reset!(env)
@test typeof(interact!(1, env)) == Tuple{Array{Float64, 1}, Float64, Bool}
@test typeof(getstate(env)) == Tuple{Array{Float64, 1}, Bool}
@test typeof(reset!(env)) == NamedTuple{(:observation,), Tuple{PyArray{Float64, 1}}}
@test typeof(interact!(env, 1)) == NamedTuple{(:observation, :reward, :isdone), Tuple{Array{Float64, 1}, Float64, Bool}}
end

0 comments on commit 631626b

Please sign in to comment.