Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update gym.jl #1009

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
using .PyCall

function GymEnv(name::String; seed::Union{Int,Nothing}=nothing)
if !PyCall.pyexists("gym")
if !PyCall.pyexists("gymnasium")
error(
"Cannot import module 'gym'.\n\nIf you did not yet install it, try running\n`ReinforcementLearningEnvironments.install_gym()`\n",
"Cannot import module 'gymnasium'.\n\nIf you did not yet install it, try running\n`ReinforcementLearningEnvironments.install_gym()`\n",
)
end
gym = pyimport_conda("gym", "gym")
gym = pyimport_conda("gymnasium", "gymnasium")
if PyCall.pyexists("d4rl")
pyimport("d4rl")
end
Expand Down Expand Up @@ -69,26 +69,45 @@ RLBase.action_space(env::GymEnv) = env.action_space
RLBase.state_space(env::GymEnv) = env.observation_space

function RLBase.reward(env::GymEnv{T}) where {T}
if pyisinstance(env.state, PyCall.@pyglobalobj :PyTuple_Type) && length(env.state) == 4
obs, reward, isdone, info = convert(Tuple{T,Float64,Bool,PyDict}, env.state)
if pyisinstance(env.state, PyCall.@pyglobalobj :PyTuple_Type) && length(env.state) == 5
_, reward, = convert(Tuple{T,Float64,Bool,Bool,PyDict}, env.state)
reward
elseif pyisinstance(env.state, PyCall.@pyglobalobj :PyTuple_Type) && length(env.state) == 4
_, reward, = convert(Tuple{T,Float64,Bool,PyDict}, env.state)
reward
else
0.0
end
end

function RLBase.is_terminated(env::GymEnv{T}) where {T}
if pyisinstance(env.state, PyCall.@pyglobalobj :PyTuple_Type) && length(env.state) == 4
obs, reward, isdone, info = convert(Tuple{T,Float64,Bool,PyDict}, env.state)
if pyisinstance(env.state, PyCall.@pyglobalobj :PyTuple_Type) && length(env.state) == 5
_, _, isterminated, = convert(Tuple{T,Float64,Bool,Bool,PyDict}, env.state)
isterminated
elseif pyisinstance(env.state, PyCall.@pyglobalobj :PyTuple_Type) && length(env.state) == 4
@warn "Gym version outdated. Update gym to obtain termination and truncation info instead of done signal."
_, _, isdone, = convert(Tuple{T,Float64,Bool,PyDict}, env.state)
isdone
else
false
end
end

function is_truncated(env::GymEnv{T}) where {T}
jeremiahpslewis marked this conversation as resolved.
Show resolved Hide resolved
jeremiahpslewis marked this conversation as resolved.
Show resolved Hide resolved
if pyisinstance(env.state, PyCall.@pyglobalobj :PyTuple_Type) && length(env.state) == 5
_, _, _, istruncated, = convert(Tuple{T,Float64,Bool,Bool,PyDict}, env.state)
istruncated
else
false
end
end

function RLBase.state(env::GymEnv{T}) where {T}
if pyisinstance(env.state, PyCall.@pyglobalobj :PyTuple_Type) && length(env.state) == 4
obs, reward, isdone, info = convert(Tuple{T,Float64,Bool,PyDict}, env.state)
if pyisinstance(env.state, PyCall.@pyglobalobj :PyTuple_Type) && length(env.state) == 5
obs, = convert(Tuple{T,Float64,Bool,Bool,PyDict}, env.state)
obs
elseif pyisinstance(env.state, PyCall.@pyglobalobj :PyTuple_Type) && length(env.state) == 4
obs, = convert(Tuple{T,Float64,Bool,PyDict}, env.state)
obs
else
convert(T, env.state)
Expand Down Expand Up @@ -123,19 +142,19 @@ end

function list_gym_env_names(;
modules=[
"gym.envs.algorithmic",
"gym.envs.box2d",
"gym.envs.classic_control",
"gym.envs.mujoco",
"gym.envs.mujoco.ant_v3",
"gym.envs.mujoco.half_cheetah_v3",
"gym.envs.mujoco.hopper_v3",
"gym.envs.mujoco.humanoid_v3",
"gym.envs.mujoco.swimmer_v3",
"gym.envs.mujoco.walker2d_v3",
"gym.envs.robotics",
"gym.envs.toy_text",
"gym.envs.unittest",
"gymnasium.envs.algorithmic",
"gymnasium.envs.box2d",
"gymnasium.envs.classic_control",
"gymnasium.envs.mujoco",
"gymnasium.envs.mujoco.ant_v3",
"gymnasium.envs.mujoco.half_cheetah_v3",
"gymnasium.envs.mujoco.hopper_v3",
"gymnasium.envs.mujoco.humanoid_v3",
"gymnasium.envs.mujoco.swimmer_v3",
"gymnasium.envs.mujoco.walker2d_v3",
"gymnasium.envs.robotics",
"gymnasium.envs.toy_text",
"gymnasium.envs.unittest",
"d4rl.pointmaze",
"d4rl.hand_manipulation_suite",
"d4rl.gym_mujoco.gym_envs",
Expand All @@ -147,14 +166,14 @@ function list_gym_env_names(;
if PyCall.pyexists("d4rl")
pyimport("d4rl")
end
gym = pyimport("gym")
gym = pyimport("gymnasium")
[x.id for x in values(gym.envs.registry) if split(x.entry_point, ':')[1] in modules]
end

"""
install_gym(; packages = ["gym", "pybullet"])
install_gym(; packages = ["gymnasium", "pybullet"])
"""
function install_gym(; packages=["gym", "pybullet"])
function install_gym(; packages=["gymnasium", "pybullet"])
# Use eventual proxy info
proxy_arg = String[]
if haskey(ENV, "http_proxy")
Expand Down
2 changes: 1 addition & 1 deletion src/ReinforcementLearningEnvironments/test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ using OrdinaryDiffEq
using TimerOutputs
using Conda

Conda.add("gym")
Conda.add("gymnasium")
Conda.add("numpy")

@testset "ReinforcementLearningEnvironments" begin
Expand Down