diff --git a/ding/envs/env/ding_env_wrapper.py b/ding/envs/env/ding_env_wrapper.py
index dc67e826bd..176d96b68e 100644
--- a/ding/envs/env/ding_env_wrapper.py
+++ b/ding/envs/env/ding_env_wrapper.py
@@ -7,7 +7,7 @@
 import treetensor.numpy as tnp
 
 from ding.envs.common.common_function import affine_transform
-from ding.envs.env_wrappers import create_env_wrapper
+from ding.envs.env_wrappers import create_env_wrapper, GymToGymnasiumWrapper
 from ding.torch_utils import to_ndarray
 from ding.utils import CloudPickleWrapper
 from .base_env import BaseEnv, BaseEnvTimestep
@@ -23,7 +23,14 @@ class DingEnvWrapper(BaseEnv):
         create_evaluator_env_cfg, enable_save_replay, observation_space, action_space, reward_space, clone
     """
 
-    def __init__(self, env: gym.Env = None, cfg: dict = None, seed_api: bool = True, caller: str = 'collector') -> None:
+    def __init__(
+            self,
+            env: Union[gym.Env, gymnasium.Env] = None,
+            cfg: dict = None,
+            seed_api: bool = True,
+            caller: str = 'collector',
+            is_gymnasium: bool = False
+    ) -> None:
         """
         Overview:
             Initialize the DingEnvWrapper. Either an environment instance or a config to create the environment \
@@ -32,17 +39,20 @@ def __init__(self, env: gym.Env = None, cfg: dict = None, seed_api: bool = True,
             usually used in simple environments. For the latter, i.e., a config to create an environment instance: \
             The `cfg` parameter must contain `env_id`.
         Arguments:
-            - env (:obj:`gym.Env`): An environment instance to be wrapped.
+            - env (:obj:`Union[gym.Env, gymnasium.Env]`): An environment instance to be wrapped.
             - cfg (:obj:`dict`): The configuration dictionary to create an environment instance.
             - seed_api (:obj:`bool`): Whether to use seed API. Defaults to True.
             - caller (:obj:`str`): A string representing the caller of this method, including ``collector`` or \
                 ``evaluator``. Different caller may need different wrappers. Default is 'collector'.
+            - is_gymnasium (:obj:`bool`): Whether the environment is a gymnasium environment. Defaults to False, i.e., \
+                the environment is a gym environment.
         """
         self._env = None
         self._raw_env = env
         self._cfg = cfg
         self._seed_api = seed_api  # some env may disable `env.seed` api
         self._caller = caller
+
         if self._cfg is None:
             self._cfg = {}
         self._cfg = EasyDict(self._cfg)
@@ -55,6 +65,7 @@ def __init__(self, env: gym.Env = None, cfg: dict = None, seed_api: bool = True,
         if 'env_id' not in self._cfg:
             self._cfg.env_id = None
         if env is not None:
+            self._is_gymnasium = isinstance(env, gymnasium.Env)
             self._env = env
             self._wrap_env(caller)
             self._observation_space = self._env.observation_space
@@ -66,6 +77,7 @@ def __init__(self, env: gym.Env = None, cfg: dict = None, seed_api: bool = True,
             self._init_flag = True
         else:
             assert 'env_id' in self._cfg
+            self._is_gymnasium = is_gymnasium
             self._init_flag = False
             self._observation_space = None
             self._action_space = None
@@ -82,7 +94,8 @@ def reset(self) -> np.ndarray:
             - obs (:obj:`Dict`): The new observation after reset.
         """
         if not self._init_flag:
-            self._env = gym.make(self._cfg.env_id)
+            gym_proxy = gymnasium if self._is_gymnasium else gym
+            self._env = gym_proxy.make(self._cfg.env_id)
             self._wrap_env(self._caller)
             self._observation_space = self._env.observation_space
             self._action_space = self._env.action_space
@@ -98,29 +111,16 @@ def reset(self) -> np.ndarray:
                 name_prefix='rl-video-{}'.format(id(self))
             )
             self._replay_path = None
-        if isinstance(self._env, gym.Env):
-            if hasattr(self, '_seed') and hasattr(self, '_dynamic_seed') and self._dynamic_seed:
-                np_seed = 100 * np.random.randint(1, 1000)
-                if self._seed_api:
-                    self._env.seed(self._seed + np_seed)
-                self._action_space.seed(self._seed + np_seed)
-            elif hasattr(self, '_seed'):
-                if self._seed_api:
-                    self._env.seed(self._seed)
-                self._action_space.seed(self._seed)
-            obs = self._env.reset()
-        elif isinstance(self._env, gymnasium.Env):
-            if hasattr(self, '_seed') and hasattr(self, '_dynamic_seed') and self._dynamic_seed:
-                np_seed = 100 * np.random.randint(1, 1000)
-                self._action_space.seed(self._seed + np_seed)
-                obs = self._env.reset(seed=self._seed + np_seed)
-            elif hasattr(self, '_seed'):
-                self._action_space.seed(self._seed)
-                obs = self._env.reset(seed=self._seed)
-            else:
-                obs = self._env.reset()
-        else:
-            raise RuntimeError("not support env type: {}".format(type(self._env)))
+        if hasattr(self, '_seed') and hasattr(self, '_dynamic_seed') and self._dynamic_seed:
+            np_seed = 100 * np.random.randint(1, 1000)
+            if self._seed_api:
+                self._env.seed(self._seed + np_seed)
+            self._action_space.seed(self._seed + np_seed)
+        elif hasattr(self, '_seed'):
+            if self._seed_api:
+                self._env.seed(self._seed)
+            self._action_space.seed(self._seed)
+        obs = self._env.reset()
         if self.observation_space.dtype == np.float32:
             obs = to_ndarray(obs, dtype=np.float32)
         else:
@@ -221,7 +221,7 @@ def random_action(self) -> np.ndarray:
         random_action = self.action_space.sample()
         if isinstance(random_action, np.ndarray):
             pass
-        elif isinstance(random_action, int):
+        elif isinstance(random_action, (int, np.int64)):
             random_action = to_ndarray([random_action], dtype=np.int64)
         elif isinstance(random_action, dict):
             random_action = to_ndarray(random_action)
@@ -241,6 +241,8 @@ def _wrap_env(self, caller: str = 'collector') -> None:
             - caller (:obj:`str`): The caller of the environment, including ``collector`` or ``evaluator``. \
                 Different caller may need different wrappers. Default is 'collector'.
         """
+        if self._is_gymnasium:
+            self._env = GymToGymnasiumWrapper(self._env)
         # wrapper_cfgs: Union[str, List]
         wrapper_cfgs = self._cfg.env_wrapper
         if isinstance(wrapper_cfgs, str):
@@ -362,4 +364,4 @@ def clone(self, caller: str = 'collector') -> BaseEnv:
                 raw_env.__setattr__('spec', spec)
             except Exception:
                 raw_env = self._raw_env
-        return DingEnvWrapper(raw_env, self._cfg, self._seed_api, caller)
+        return DingEnvWrapper(raw_env, self._cfg, self._seed_api, caller, self._is_gymnasium)
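
Usage sketch (illustrative, not part of the patch): with this change DingEnvWrapper accepts both env APIs through two entry points. `CartPole-v1` and the default wrapper config are assumptions here, not taken from the diff:

    import gymnasium
    from easydict import EasyDict
    from ding.envs import DingEnvWrapper

    # Eager path: pass an instance; is_gymnasium is inferred via isinstance(env, gymnasium.Env).
    env = DingEnvWrapper(gymnasium.make('CartPole-v1'))
    obs = env.reset()

    # Lazy path: pass a config; is_gymnasium must be set explicitly, because there is
    # no instance to inspect until reset() calls gymnasium.make(cfg.env_id).
    lazy_env = DingEnvWrapper(cfg=EasyDict(env_id='CartPole-v1'), is_gymnasium=True)
    obs = lazy_env.reset()
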
""" if not self._init_flag: - self._env = gym.make(self._cfg.env_id) + gym_proxy = gymnasium if self._is_gymnasium else gym + self._env = gym_proxy.make(self._cfg.env_id) self._wrap_env(self._caller) self._observation_space = self._env.observation_space self._action_space = self._env.action_space @@ -98,29 +111,16 @@ def reset(self) -> np.ndarray: name_prefix='rl-video-{}'.format(id(self)) ) self._replay_path = None - if isinstance(self._env, gym.Env): - if hasattr(self, '_seed') and hasattr(self, '_dynamic_seed') and self._dynamic_seed: - np_seed = 100 * np.random.randint(1, 1000) - if self._seed_api: - self._env.seed(self._seed + np_seed) - self._action_space.seed(self._seed + np_seed) - elif hasattr(self, '_seed'): - if self._seed_api: - self._env.seed(self._seed) - self._action_space.seed(self._seed) - obs = self._env.reset() - elif isinstance(self._env, gymnasium.Env): - if hasattr(self, '_seed') and hasattr(self, '_dynamic_seed') and self._dynamic_seed: - np_seed = 100 * np.random.randint(1, 1000) - self._action_space.seed(self._seed + np_seed) - obs = self._env.reset(seed=self._seed + np_seed) - elif hasattr(self, '_seed'): - self._action_space.seed(self._seed) - obs = self._env.reset(seed=self._seed) - else: - obs = self._env.reset() - else: - raise RuntimeError("not support env type: {}".format(type(self._env))) + if hasattr(self, '_seed') and hasattr(self, '_dynamic_seed') and self._dynamic_seed: + np_seed = 100 * np.random.randint(1, 1000) + if self._seed_api: + self._env.seed(self._seed + np_seed) + self._action_space.seed(self._seed + np_seed) + elif hasattr(self, '_seed'): + if self._seed_api: + self._env.seed(self._seed) + self._action_space.seed(self._seed) + obs = self._env.reset() if self.observation_space.dtype == np.float32: obs = to_ndarray(obs, dtype=np.float32) else: @@ -221,7 +221,7 @@ def random_action(self) -> np.ndarray: random_action = self.action_space.sample() if isinstance(random_action, np.ndarray): pass - elif isinstance(random_action, int): + elif isinstance(random_action, (int, np.int64)): random_action = to_ndarray([random_action], dtype=np.int64) elif isinstance(random_action, dict): random_action = to_ndarray(random_action) @@ -241,6 +241,8 @@ def _wrap_env(self, caller: str = 'collector') -> None: - caller (:obj:`str`): The caller of the environment, including ``collector`` or ``evaluator``. \ Different caller may need different wrappers. Default is 'collector'. 
""" + if self._is_gymnasium: + self._env = GymToGymnasiumWrapper(self._env) # wrapper_cfgs: Union[str, List] wrapper_cfgs = self._cfg.env_wrapper if isinstance(wrapper_cfgs, str): @@ -362,4 +364,4 @@ def clone(self, caller: str = 'collector') -> BaseEnv: raw_env.__setattr__('spec', spec) except Exception: raw_env = self._raw_env - return DingEnvWrapper(raw_env, self._cfg, self._seed_api, caller) + return DingEnvWrapper(raw_env, self._cfg, self._seed_api, caller, self._is_gymnasium) diff --git a/ding/envs/env/tests/test_ding_env_wrapper.py b/ding/envs/env/tests/test_ding_env_wrapper.py index 7d53adbfd3..0c9cd9abb7 100644 --- a/ding/envs/env/tests/test_ding_env_wrapper.py +++ b/ding/envs/env/tests/test_ding_env_wrapper.py @@ -1,4 +1,5 @@ import gym +import gymnasium import numpy as np import pytest from easydict import EasyDict @@ -68,6 +69,27 @@ def test_cartpole_pendulum(self, env_id): # assert isinstance(action, np.ndarray) print('random_action: {}, action_space: {}'.format(action.shape, ding_env.action_space)) + @pytest.mark.unittest + @pytest.mark.parametrize('env_id', ['CartPole-v0', 'Pendulum-v1']) + def test_cartpole_pendulum_gymnasium(self, env_id): + env = gymnasium.make(env_id) + ding_env = DingEnvWrapper(env=env) + print(ding_env.observation_space, ding_env.action_space, ding_env.reward_space) + cfg = EasyDict(dict( + collector_env_num=16, + evaluator_env_num=3, + is_train=True, + )) + l1 = ding_env.create_collector_env_cfg(cfg) + assert isinstance(l1, list) + l1 = ding_env.create_evaluator_env_cfg(cfg) + assert isinstance(l1, list) + obs = ding_env.reset() + assert isinstance(obs, np.ndarray) + action = ding_env.random_action() + # assert isinstance(action, np.ndarray) + print('random_action: {}, action_space: {}'.format(action.shape, ding_env.action_space)) + @pytest.mark.envtest def test_mujoco(self): env_cfg = EasyDict( diff --git a/ding/envs/env_wrappers/env_wrappers.py b/ding/envs/env_wrappers/env_wrappers.py index 08b1ce4eb1..a33938d20a 100644 --- a/ding/envs/env_wrappers/env_wrappers.py +++ b/ding/envs/env_wrappers/env_wrappers.py @@ -1438,7 +1438,7 @@ class GymToGymnasiumWrapper(gym.Wrapper): Overview: This class is used to wrap a gymnasium environment to a gym environment. Interfaces: - __init__, seed, reset + __init__, seed, reset, step """ def __init__(self, env: gymnasium.Env) -> None: @@ -1470,9 +1470,20 @@ def reset(self) -> np.ndarray: - observation (:obj:`np.ndarray`): The new observation after reset. """ if self.seed is not None: - return self.env.reset(seed=self._seed) + obs, info = self.env.reset(seed=self._seed) else: - return self.env.reset() + obs, info = self.env.reset() + return obs + + def step(self, *args, **kwargs): + """ + Overview: + Execute the given action in the environment, and return the new observation, + reward, done status, and info. To keep consistency with gym, the done status should be the either \ + terminated=True or truncated=True. 
+ """ + obs, rew, terminated, truncated, info = self.env.step(*args, **kwargs) + return obs, rew, terminated or truncated, info @ENV_WRAPPER_REGISTRY.register('reward_in_obs') diff --git a/ding/example/dqn_nstep_gymnasium.py b/ding/example/dqn_nstep_gymnasium.py new file mode 100644 index 0000000000..f712ce8cca --- /dev/null +++ b/ding/example/dqn_nstep_gymnasium.py @@ -0,0 +1,49 @@ +import gymnasium as gym +from ditk import logging +from ding.model import DQN +from ding.policy import DQNPolicy +from ding.envs import DingEnvWrapper, BaseEnvManagerV2 +from ding.data import DequeBuffer +from ding.config import compile_config +from ding.framework import task +from ding.framework.context import OnlineRLContext +from ding.framework.middleware import OffPolicyLearner, StepCollector, interaction_evaluator, data_pusher, \ + eps_greedy_handler, CkptSaver, nstep_reward_enhancer, final_ctx_saver +from ding.utils import set_pkg_seed +from dizoo.classic_control.cartpole.config.cartpole_dqn_config import main_config, create_config + + +def main(): + logging.getLogger().setLevel(logging.INFO) + main_config.exp_name = 'cartpole_dqn_nstep_gymnasium' + main_config.policy.nstep = 3 + cfg = compile_config(main_config, create_cfg=create_config, auto=True) + with task.start(async_mode=False, ctx=OnlineRLContext()): + collector_env = BaseEnvManagerV2( + env_fn=[lambda: DingEnvWrapper(gym.make("CartPole-v0")) for _ in range(cfg.env.collector_env_num)], + cfg=cfg.env.manager + ) + evaluator_env = BaseEnvManagerV2( + env_fn=[lambda: DingEnvWrapper(gym.make("CartPole-v0")) for _ in range(cfg.env.evaluator_env_num)], + cfg=cfg.env.manager + ) + + set_pkg_seed(cfg.seed, use_cuda=cfg.policy.cuda) + + model = DQN(**cfg.policy.model) + buffer_ = DequeBuffer(size=cfg.policy.other.replay_buffer.replay_buffer_size) + policy = DQNPolicy(cfg.policy, model=model) + + task.use(interaction_evaluator(cfg, policy.eval_mode, evaluator_env)) + task.use(eps_greedy_handler(cfg)) + task.use(StepCollector(cfg, policy.collect_mode, collector_env)) + task.use(nstep_reward_enhancer(cfg)) + task.use(data_pusher(cfg, buffer_)) + task.use(OffPolicyLearner(cfg, policy.learn_mode, buffer_)) + task.use(CkptSaver(policy, cfg.exp_name, train_freq=100)) + task.use(final_ctx_saver(cfg.exp_name)) + task.run() + + +if __name__ == "__main__": + main()