
Commit 96ccaed

feature(nyz): adapt DingEnvWrapper to gymnasium (#817)
1 parent 7f95159 commit 96ccaed

4 files changed, +116 -32 lines changed

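This commit lets DingEnvWrapper accept gymnasium environments alongside gym ones. A minimal usage sketch, mirroring the new unit test in the diff below:

import gymnasium
from ding.envs import DingEnvWrapper

# Wrap a gymnasium env directly; __init__ detects gymnasium.Env via isinstance.
ding_env = DingEnvWrapper(env=gymnasium.make('CartPole-v0'))
obs = ding_env.reset()  # gym-style single return value (np.ndarray)
action = ding_env.random_action()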

ding/envs/env/ding_env_wrapper.py (+31 -29)

@@ -7,7 +7,7 @@
 import treetensor.numpy as tnp
 
 from ding.envs.common.common_function import affine_transform
-from ding.envs.env_wrappers import create_env_wrapper
+from ding.envs.env_wrappers import create_env_wrapper, GymToGymnasiumWrapper
 from ding.torch_utils import to_ndarray
 from ding.utils import CloudPickleWrapper
 from .base_env import BaseEnv, BaseEnvTimestep
@@ -23,7 +23,14 @@ class DingEnvWrapper(BaseEnv):
         create_evaluator_env_cfg, enable_save_replay, observation_space, action_space, reward_space, clone
     """
 
-    def __init__(self, env: gym.Env = None, cfg: dict = None, seed_api: bool = True, caller: str = 'collector') -> None:
+    def __init__(
+            self,
+            env: Union[gym.Env, gymnasium.Env] = None,
+            cfg: dict = None,
+            seed_api: bool = True,
+            caller: str = 'collector',
+            is_gymnasium: bool = False
+    ) -> None:
         """
         Overview:
             Initialize the DingEnvWrapper. Either an environment instance or a config to create the environment \
@@ -32,17 +39,20 @@ def __init__(self, env: gym.Env = None, cfg: dict = None, seed_api: bool = True,
             usually used in simple environments. For the latter, i.e., a config to create an environment instance: \
             The `cfg` parameter must contain `env_id`.
         Arguments:
-            - env (:obj:`gym.Env`): An environment instance to be wrapped.
+            - env (:obj:`Union[gym.Env, gymnasium.Env]`): An environment instance to be wrapped.
             - cfg (:obj:`dict`): The configuration dictionary to create an environment instance.
             - seed_api (:obj:`bool`): Whether to use seed API. Defaults to True.
             - caller (:obj:`str`): A string representing the caller of this method, including ``collector`` or \
                 ``evaluator``. Different caller may need different wrappers. Default is 'collector'.
+            - is_gymnasium (:obj:`bool`): Whether the environment is a gymnasium environment. Defaults to False, i.e., \
+                the environment is a gym environment.
         """
         self._env = None
         self._raw_env = env
         self._cfg = cfg
         self._seed_api = seed_api  # some env may disable `env.seed` api
         self._caller = caller
+
         if self._cfg is None:
             self._cfg = {}
         self._cfg = EasyDict(self._cfg)
@@ -55,6 +65,7 @@ def __init__(self, env: gym.Env = None, cfg: dict = None, seed_api: bool = True,
         if 'env_id' not in self._cfg:
             self._cfg.env_id = None
         if env is not None:
+            self._is_gymnasium = isinstance(env, gymnasium.Env)
             self._env = env
             self._wrap_env(caller)
             self._observation_space = self._env.observation_space
@@ -66,6 +77,7 @@ def __init__(self, env: gym.Env = None, cfg: dict = None, seed_api: bool = True,
             self._init_flag = True
         else:
             assert 'env_id' in self._cfg
+            self._is_gymnasium = is_gymnasium
             self._init_flag = False
             self._observation_space = None
             self._action_space = None
@@ -82,7 +94,8 @@ def reset(self) -> np.ndarray:
             - obs (:obj:`Dict`): The new observation after reset.
         """
         if not self._init_flag:
-            self._env = gym.make(self._cfg.env_id)
+            gym_proxy = gymnasium if self._is_gymnasium else gym
+            self._env = gym_proxy.make(self._cfg.env_id)
             self._wrap_env(self._caller)
             self._observation_space = self._env.observation_space
             self._action_space = self._env.action_space
@@ -98,29 +111,16 @@ def reset(self) -> np.ndarray:
                 name_prefix='rl-video-{}'.format(id(self))
             )
             self._replay_path = None
-        if isinstance(self._env, gym.Env):
-            if hasattr(self, '_seed') and hasattr(self, '_dynamic_seed') and self._dynamic_seed:
-                np_seed = 100 * np.random.randint(1, 1000)
-                if self._seed_api:
-                    self._env.seed(self._seed + np_seed)
-                self._action_space.seed(self._seed + np_seed)
-            elif hasattr(self, '_seed'):
-                if self._seed_api:
-                    self._env.seed(self._seed)
-                self._action_space.seed(self._seed)
-            obs = self._env.reset()
-        elif isinstance(self._env, gymnasium.Env):
-            if hasattr(self, '_seed') and hasattr(self, '_dynamic_seed') and self._dynamic_seed:
-                np_seed = 100 * np.random.randint(1, 1000)
-                self._action_space.seed(self._seed + np_seed)
-                obs = self._env.reset(seed=self._seed + np_seed)
-            elif hasattr(self, '_seed'):
-                self._action_space.seed(self._seed)
-                obs = self._env.reset(seed=self._seed)
-            else:
-                obs = self._env.reset()
-        else:
-            raise RuntimeError("not support env type: {}".format(type(self._env)))
+        if hasattr(self, '_seed') and hasattr(self, '_dynamic_seed') and self._dynamic_seed:
+            np_seed = 100 * np.random.randint(1, 1000)
+            if self._seed_api:
+                self._env.seed(self._seed + np_seed)
+            self._action_space.seed(self._seed + np_seed)
+        elif hasattr(self, '_seed'):
+            if self._seed_api:
+                self._env.seed(self._seed)
+            self._action_space.seed(self._seed)
+        obs = self._env.reset()
         if self.observation_space.dtype == np.float32:
             obs = to_ndarray(obs, dtype=np.float32)
         else:
@@ -221,7 +221,7 @@ def random_action(self) -> np.ndarray:
         random_action = self.action_space.sample()
         if isinstance(random_action, np.ndarray):
             pass
-        elif isinstance(random_action, int):
+        elif isinstance(random_action, (int, np.int64)):
             random_action = to_ndarray([random_action], dtype=np.int64)
         elif isinstance(random_action, dict):
             random_action = to_ndarray(random_action)
@@ -241,6 +241,8 @@ def _wrap_env(self, caller: str = 'collector') -> None:
             - caller (:obj:`str`): The caller of the environment, including ``collector`` or ``evaluator``. \
                 Different caller may need different wrappers. Default is 'collector'.
         """
+        if self._is_gymnasium:
+            self._env = GymToGymnasiumWrapper(self._env)
         # wrapper_cfgs: Union[str, List]
         wrapper_cfgs = self._cfg.env_wrapper
         if isinstance(wrapper_cfgs, str):
@@ -362,4 +364,4 @@ def clone(self, caller: str = 'collector') -> BaseEnv:
                 raw_env.__setattr__('spec', spec)
         except Exception:
             raw_env = self._raw_env
-        return DingEnvWrapper(raw_env, self._cfg, self._seed_api, caller)
+        return DingEnvWrapper(raw_env, self._cfg, self._seed_api, caller, self._is_gymnasium)
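
When no environment instance is passed, construction is deferred to reset(), and the new is_gymnasium flag picks the package used to build it. A sketch of that path, assuming the class's config defaults (e.g. env_wrapper) fill in as usual:

# Config-driven construction (hypothetical values): reset() calls
# gymnasium.make(cfg.env_id) because is_gymnasium=True selects gymnasium as gym_proxy.
ding_env = DingEnvWrapper(cfg={'env_id': 'CartPole-v0'}, is_gymnasium=True)
obs = ding_env.reset()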

ding/envs/env/tests/test_ding_env_wrapper.py (+22)

@@ -1,4 +1,5 @@
 import gym
+import gymnasium
 import numpy as np
 import pytest
 from easydict import EasyDict
@@ -68,6 +69,27 @@ def test_cartpole_pendulum(self, env_id):
         # assert isinstance(action, np.ndarray)
         print('random_action: {}, action_space: {}'.format(action.shape, ding_env.action_space))
 
+    @pytest.mark.unittest
+    @pytest.mark.parametrize('env_id', ['CartPole-v0', 'Pendulum-v1'])
+    def test_cartpole_pendulum_gymnasium(self, env_id):
+        env = gymnasium.make(env_id)
+        ding_env = DingEnvWrapper(env=env)
+        print(ding_env.observation_space, ding_env.action_space, ding_env.reward_space)
+        cfg = EasyDict(dict(
+            collector_env_num=16,
+            evaluator_env_num=3,
+            is_train=True,
+        ))
+        l1 = ding_env.create_collector_env_cfg(cfg)
+        assert isinstance(l1, list)
+        l1 = ding_env.create_evaluator_env_cfg(cfg)
+        assert isinstance(l1, list)
+        obs = ding_env.reset()
+        assert isinstance(obs, np.ndarray)
+        action = ding_env.random_action()
+        # assert isinstance(action, np.ndarray)
+        print('random_action: {}, action_space: {}'.format(action.shape, ding_env.action_space))
+
     @pytest.mark.envtest
     def test_mujoco(self):
         env_cfg = EasyDict(
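
Assuming a standard pytest setup, the new gymnasium case can be run on its own by selecting it by name:

pytest -sv ding/envs/env/tests/test_ding_env_wrapper.py -k gymnasium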

ding/envs/env_wrappers/env_wrappers.py (+14 -3)

@@ -1438,7 +1438,7 @@ class GymToGymnasiumWrapper(gym.Wrapper):
     Overview:
         This class is used to wrap a gymnasium environment to a gym environment.
     Interfaces:
-        __init__, seed, reset
+        __init__, seed, reset, step
     """
 
     def __init__(self, env: gymnasium.Env) -> None:
@@ -1470,9 +1470,20 @@ def reset(self) -> np.ndarray:
             - observation (:obj:`np.ndarray`): The new observation after reset.
         """
         if self.seed is not None:
-            return self.env.reset(seed=self._seed)
+            obs, info = self.env.reset(seed=self._seed)
         else:
-            return self.env.reset()
+            obs, info = self.env.reset()
+        return obs
+
+    def step(self, *args, **kwargs):
+        """
+        Overview:
+            Execute the given action in the environment, and return the new observation, \
+            reward, done status, and info. To keep consistency with gym, done is True if \
+            either terminated or truncated is True.
+        """
+        obs, rew, terminated, truncated, info = self.env.step(*args, **kwargs)
+        return obs, rew, terminated or truncated, info
 
 
 @ENV_WRAPPER_REGISTRY.register('reward_in_obs')
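
The new step method collapses gymnasium's five-tuple into gym's four-tuple by OR-ing the two termination flags. A standalone sketch of the same conversion, using the gymnasium API directly:

import gymnasium

env = gymnasium.make('CartPole-v0')
obs, info = env.reset(seed=0)  # gymnasium reset returns (obs, info)
obs, rew, terminated, truncated, info = env.step(env.action_space.sample())
done = terminated or truncated  # the gym-style done flag the wrapper returns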

ding/example/dqn_nstep_gymnasium.py (+49)

@@ -0,0 +1,49 @@
+import gymnasium as gym
+from ditk import logging
+from ding.model import DQN
+from ding.policy import DQNPolicy
+from ding.envs import DingEnvWrapper, BaseEnvManagerV2
+from ding.data import DequeBuffer
+from ding.config import compile_config
+from ding.framework import task
+from ding.framework.context import OnlineRLContext
+from ding.framework.middleware import OffPolicyLearner, StepCollector, interaction_evaluator, data_pusher, \
+    eps_greedy_handler, CkptSaver, nstep_reward_enhancer, final_ctx_saver
+from ding.utils import set_pkg_seed
+from dizoo.classic_control.cartpole.config.cartpole_dqn_config import main_config, create_config
+
+
+def main():
+    logging.getLogger().setLevel(logging.INFO)
+    main_config.exp_name = 'cartpole_dqn_nstep_gymnasium'
+    main_config.policy.nstep = 3
+    cfg = compile_config(main_config, create_cfg=create_config, auto=True)
+    with task.start(async_mode=False, ctx=OnlineRLContext()):
+        collector_env = BaseEnvManagerV2(
+            env_fn=[lambda: DingEnvWrapper(gym.make("CartPole-v0")) for _ in range(cfg.env.collector_env_num)],
+            cfg=cfg.env.manager
+        )
+        evaluator_env = BaseEnvManagerV2(
+            env_fn=[lambda: DingEnvWrapper(gym.make("CartPole-v0")) for _ in range(cfg.env.evaluator_env_num)],
+            cfg=cfg.env.manager
+        )
+
+        set_pkg_seed(cfg.seed, use_cuda=cfg.policy.cuda)
+
+        model = DQN(**cfg.policy.model)
+        buffer_ = DequeBuffer(size=cfg.policy.other.replay_buffer.replay_buffer_size)
+        policy = DQNPolicy(cfg.policy, model=model)
+
+        task.use(interaction_evaluator(cfg, policy.eval_mode, evaluator_env))
+        task.use(eps_greedy_handler(cfg))
+        task.use(StepCollector(cfg, policy.collect_mode, collector_env))
+        task.use(nstep_reward_enhancer(cfg))
+        task.use(data_pusher(cfg, buffer_))
+        task.use(OffPolicyLearner(cfg, policy.learn_mode, buffer_))
+        task.use(CkptSaver(policy, cfg.exp_name, train_freq=100))
+        task.use(final_ctx_saver(cfg.exp_name))
+        task.run()
+
+
+if __name__ == "__main__":
+    main()
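
Assuming DI-engine and gymnasium are installed, the example runs as a plain script:

python ding/example/dqn_nstep_gymnasium.py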
