feature(nyz): adapt DingEnvWrapper to gymnasium #817

Merged · merged 1 commit · Jul 6, 2024
60 changes: 31 additions & 29 deletions ding/envs/env/ding_env_wrapper.py
@@ -7,7 +7,7 @@
import treetensor.numpy as tnp

from ding.envs.common.common_function import affine_transform
from ding.envs.env_wrappers import create_env_wrapper
from ding.envs.env_wrappers import create_env_wrapper, GymToGymnasiumWrapper
from ding.torch_utils import to_ndarray
from ding.utils import CloudPickleWrapper
from .base_env import BaseEnv, BaseEnvTimestep
@@ -23,7 +23,14 @@
create_evaluator_env_cfg, enable_save_replay, observation_space, action_space, reward_space, clone
"""

def __init__(self, env: gym.Env = None, cfg: dict = None, seed_api: bool = True, caller: str = 'collector') -> None:
def __init__(
self,
env: Union[gym.Env, gymnasium.Env] = None,
cfg: dict = None,
seed_api: bool = True,
caller: str = 'collector',
is_gymnasium: bool = False
) -> None:
"""
Overview:
Initialize the DingEnvWrapper. Either an environment instance or a config to create the environment \
@@ -32,17 +39,20 @@
usually used in simple environments. For the latter, i.e., a config to create an environment instance: \
The `cfg` parameter must contain `env_id`.
Arguments:
- env (:obj:`gym.Env`): An environment instance to be wrapped.
- env (:obj:`Union[gym.Env, gymnasium.Env]`): An environment instance to be wrapped.
- cfg (:obj:`dict`): The configuration dictionary to create an environment instance.
- seed_api (:obj:`bool`): Whether to use seed API. Defaults to True.
- caller (:obj:`str`): A string representing the caller of this method, including ``collector`` or \
``evaluator``. Different callers may need different wrappers. Default is 'collector'.
- is_gymnasium (:obj:`bool`): Whether the environment is a gymnasium environment. Defaults to False, i.e., \
the environment is a gym environment.
"""
self._env = None
self._raw_env = env
self._cfg = cfg
self._seed_api = seed_api # some env may disable `env.seed` api
self._caller = caller

if self._cfg is None:
self._cfg = {}
self._cfg = EasyDict(self._cfg)
@@ -55,6 +65,7 @@
if 'env_id' not in self._cfg:
self._cfg.env_id = None
if env is not None:
self._is_gymnasium = isinstance(env, gymnasium.Env)
self._env = env
self._wrap_env(caller)
self._observation_space = self._env.observation_space
@@ -66,6 +77,7 @@
self._init_flag = True
else:
assert 'env_id' in self._cfg
self._is_gymnasium = is_gymnasium
self._init_flag = False
self._observation_space = None
self._action_space = None
@@ -82,7 +94,8 @@
- obs (:obj:`Dict`): The new observation after reset.
"""
if not self._init_flag:
self._env = gym.make(self._cfg.env_id)
gym_proxy = gymnasium if self._is_gymnasium else gym
self._env = gym_proxy.make(self._cfg.env_id)
self._wrap_env(self._caller)
self._observation_space = self._env.observation_space
self._action_space = self._env.action_space
@@ -98,29 +111,16 @@
name_prefix='rl-video-{}'.format(id(self))
)
self._replay_path = None
if isinstance(self._env, gym.Env):
if hasattr(self, '_seed') and hasattr(self, '_dynamic_seed') and self._dynamic_seed:
np_seed = 100 * np.random.randint(1, 1000)
if self._seed_api:
self._env.seed(self._seed + np_seed)
self._action_space.seed(self._seed + np_seed)
elif hasattr(self, '_seed'):
if self._seed_api:
self._env.seed(self._seed)
self._action_space.seed(self._seed)
obs = self._env.reset()
elif isinstance(self._env, gymnasium.Env):
if hasattr(self, '_seed') and hasattr(self, '_dynamic_seed') and self._dynamic_seed:
np_seed = 100 * np.random.randint(1, 1000)
self._action_space.seed(self._seed + np_seed)
obs = self._env.reset(seed=self._seed + np_seed)
elif hasattr(self, '_seed'):
self._action_space.seed(self._seed)
obs = self._env.reset(seed=self._seed)
else:
obs = self._env.reset()
else:
raise RuntimeError("not support env type: {}".format(type(self._env)))
if hasattr(self, '_seed') and hasattr(self, '_dynamic_seed') and self._dynamic_seed:
np_seed = 100 * np.random.randint(1, 1000)
if self._seed_api:
self._env.seed(self._seed + np_seed)
self._action_space.seed(self._seed + np_seed)
elif hasattr(self, '_seed'):
if self._seed_api:
self._env.seed(self._seed)
self._action_space.seed(self._seed)
obs = self._env.reset()
if self.observation_space.dtype == np.float32:
obs = to_ndarray(obs, dtype=np.float32)
else:
@@ -221,7 +221,7 @@
random_action = self.action_space.sample()
if isinstance(random_action, np.ndarray):
pass
elif isinstance(random_action, int):
elif isinstance(random_action, (int, np.int64)):
random_action = to_ndarray([random_action], dtype=np.int64)
elif isinstance(random_action, dict):
random_action = to_ndarray(random_action)
@@ -241,6 +241,8 @@
- caller (:obj:`str`): The caller of the environment, including ``collector`` or ``evaluator``. \
Different callers may need different wrappers. Default is 'collector'.
"""
if self._is_gymnasium:
self._env = GymToGymnasiumWrapper(self._env)
# wrapper_cfgs: Union[str, List]
wrapper_cfgs = self._cfg.env_wrapper
if isinstance(wrapper_cfgs, str):
@@ -362,4 +364,4 @@
raw_env.__setattr__('spec', spec)
except Exception:
raw_env = self._raw_env
return DingEnvWrapper(raw_env, self._cfg, self._seed_api, caller)
return DingEnvWrapper(raw_env, self._cfg, self._seed_api, caller, self._is_gymnasium)

Codecov / codecov/patch warning: added line ding/envs/env/ding_env_wrapper.py#L367 was not covered by tests.
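A minimal usage sketch of the new gymnasium path (illustrative only, not part of this diff; it assumes CartPole-v1 is registered in gymnasium and that DI-engine's default env wrappers are used):

import gymnasium
from ding.envs import DingEnvWrapper

# Passing an existing gymnasium environment instance: is_gymnasium is inferred from the type.
env = DingEnvWrapper(gymnasium.make('CartPole-v1'))
obs = env.reset()  # gymnasium's (obs, info) tuple is unpacked inside the wrapper; obs is an np.ndarray

# Building lazily from a config: the is_gymnasium flag must be set explicitly,
# so that reset() calls gymnasium.make instead of gym.make.
lazy_env = DingEnvWrapper(cfg={'env_id': 'CartPole-v1'}, is_gymnasium=True)
obs = lazy_env.reset()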
22 changes: 22 additions & 0 deletions ding/envs/env/tests/test_ding_env_wrapper.py
@@ -1,4 +1,5 @@
import gym
import gymnasium
import numpy as np
import pytest
from easydict import EasyDict
@@ -68,6 +69,27 @@ def test_cartpole_pendulum(self, env_id):
# assert isinstance(action, np.ndarray)
print('random_action: {}, action_space: {}'.format(action.shape, ding_env.action_space))

@pytest.mark.unittest
@pytest.mark.parametrize('env_id', ['CartPole-v0', 'Pendulum-v1'])
def test_cartpole_pendulum_gymnasium(self, env_id):
env = gymnasium.make(env_id)
ding_env = DingEnvWrapper(env=env)
print(ding_env.observation_space, ding_env.action_space, ding_env.reward_space)
cfg = EasyDict(dict(
collector_env_num=16,
evaluator_env_num=3,
is_train=True,
))
l1 = ding_env.create_collector_env_cfg(cfg)
assert isinstance(l1, list)
l1 = ding_env.create_evaluator_env_cfg(cfg)
assert isinstance(l1, list)
obs = ding_env.reset()
assert isinstance(obs, np.ndarray)
action = ding_env.random_action()
# assert isinstance(action, np.ndarray)
print('random_action: {}, action_space: {}'.format(action.shape, ding_env.action_space))

@pytest.mark.envtest
def test_mujoco(self):
env_cfg = EasyDict(
17 changes: 14 additions & 3 deletions ding/envs/env_wrappers/env_wrappers.py
@@ -1438,7 +1438,7 @@ class GymToGymnasiumWrapper(gym.Wrapper):
Overview:
This class is used to wrap a gymnasium environment into a gym-style environment.
Interfaces:
__init__, seed, reset
__init__, seed, reset, step
"""

def __init__(self, env: gymnasium.Env) -> None:
@@ -1470,9 +1470,20 @@ def reset(self) -> np.ndarray:
- observation (:obj:`np.ndarray`): The new observation after reset.
"""
if self._seed is not None:
return self.env.reset(seed=self._seed)
obs, info = self.env.reset(seed=self._seed)
else:
return self.env.reset()
obs, info = self.env.reset()
return obs

def step(self, *args, **kwargs):
"""
Overview:
Execute the given action in the environment, and return the new observation, \
reward, done status, and info. To keep consistency with gym, the done status is True \
if either terminated or truncated is True.
"""
obs, rew, terminated, truncated, info = self.env.step(*args, **kwargs)
return obs, rew, terminated or truncated, info


@ENV_WRAPPER_REGISTRY.register('reward_in_obs')
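To illustrate what the updated wrapper does, a rough sketch of the gym-style interface it exposes over a gymnasium environment (illustrative only, assuming a gymnasium CartPole-v1 env; not taken from the diff):

import gymnasium
from ding.envs.env_wrappers import GymToGymnasiumWrapper

env = GymToGymnasiumWrapper(gymnasium.make('CartPole-v1'))
env.seed(0)  # stored and later forwarded as reset(seed=...)
obs = env.reset()  # gym-style reset: only the observation is returned, info is dropped
# gym-style step: gymnasium's 5-tuple is collapsed, done = terminated or truncated
obs, rew, done, info = env.step(env.action_space.sample())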
49 changes: 49 additions & 0 deletions ding/example/dqn_nstep_gymnasium.py
@@ -0,0 +1,49 @@
import gymnasium as gym
from ditk import logging
from ding.model import DQN
from ding.policy import DQNPolicy
from ding.envs import DingEnvWrapper, BaseEnvManagerV2
from ding.data import DequeBuffer
from ding.config import compile_config
from ding.framework import task
from ding.framework.context import OnlineRLContext
from ding.framework.middleware import OffPolicyLearner, StepCollector, interaction_evaluator, data_pusher, \
eps_greedy_handler, CkptSaver, nstep_reward_enhancer, final_ctx_saver
from ding.utils import set_pkg_seed
from dizoo.classic_control.cartpole.config.cartpole_dqn_config import main_config, create_config


def main():
logging.getLogger().setLevel(logging.INFO)
main_config.exp_name = 'cartpole_dqn_nstep_gymnasium'
main_config.policy.nstep = 3
cfg = compile_config(main_config, create_cfg=create_config, auto=True)
with task.start(async_mode=False, ctx=OnlineRLContext()):
collector_env = BaseEnvManagerV2(
env_fn=[lambda: DingEnvWrapper(gym.make("CartPole-v0")) for _ in range(cfg.env.collector_env_num)],
cfg=cfg.env.manager
)
evaluator_env = BaseEnvManagerV2(
env_fn=[lambda: DingEnvWrapper(gym.make("CartPole-v0")) for _ in range(cfg.env.evaluator_env_num)],
cfg=cfg.env.manager
)

set_pkg_seed(cfg.seed, use_cuda=cfg.policy.cuda)

model = DQN(**cfg.policy.model)
buffer_ = DequeBuffer(size=cfg.policy.other.replay_buffer.replay_buffer_size)
policy = DQNPolicy(cfg.policy, model=model)

task.use(interaction_evaluator(cfg, policy.eval_mode, evaluator_env))
task.use(eps_greedy_handler(cfg))
task.use(StepCollector(cfg, policy.collect_mode, collector_env))
task.use(nstep_reward_enhancer(cfg))
task.use(data_pusher(cfg, buffer_))
task.use(OffPolicyLearner(cfg, policy.learn_mode, buffer_))
task.use(CkptSaver(policy, cfg.exp_name, train_freq=100))
task.use(final_ctx_saver(cfg.exp_name))
task.run()


if __name__ == "__main__":
main()
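Assuming DI-engine and gymnasium are installed, this example should presumably be runnable as a standalone script:

python ding/example/dqn_nstep_gymnasium.py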