7
7
import treetensor .numpy as tnp
8
8
9
9
from ding .envs .common .common_function import affine_transform
10
- from ding .envs .env_wrappers import create_env_wrapper
10
+ from ding .envs .env_wrappers import create_env_wrapper , GymToGymnasiumWrapper
11
11
from ding .torch_utils import to_ndarray
12
12
from ding .utils import CloudPickleWrapper
13
13
from .base_env import BaseEnv , BaseEnvTimestep
@@ -23,7 +23,14 @@ class DingEnvWrapper(BaseEnv):
23
23
create_evaluator_env_cfg, enable_save_replay, observation_space, action_space, reward_space, clone
24
24
"""
25
25
26
- def __init__ (self , env : gym .Env = None , cfg : dict = None , seed_api : bool = True , caller : str = 'collector' ) -> None :
26
+ def __init__ (
27
+ self ,
28
+ env : Union [gym .Env , gymnasium .Env ] = None ,
29
+ cfg : dict = None ,
30
+ seed_api : bool = True ,
31
+ caller : str = 'collector' ,
32
+ is_gymnasium : bool = False
33
+ ) -> None :
27
34
"""
28
35
Overview:
29
36
Initialize the DingEnvWrapper. Either an environment instance or a config to create the environment \
@@ -32,17 +39,20 @@ def __init__(self, env: gym.Env = None, cfg: dict = None, seed_api: bool = True,
32
39
usually used in simple environments. For the latter, i.e., a config to create an environment instance: \
33
40
The `cfg` parameter must contain `env_id`.
34
41
Arguments:
35
- - env (:obj:`gym.Env`): An environment instance to be wrapped.
42
+ - env (:obj:`Union[gym.Env, gymnasium.Env]`): An environment instance to be wrapped.
36
43
- cfg (:obj:`dict`): The configuration dictionary to create an environment instance.
37
44
- seed_api (:obj:`bool`): Whether to use seed API. Defaults to True.
38
45
- caller (:obj:`str`): A string representing the caller of this method, including ``collector`` or \
39
46
``evaluator``. Different caller may need different wrappers. Default is 'collector'.
47
+ - is_gymnasium (:obj:`bool`): Whether the environment is a gymnasium environment. Defaults to False, i.e., \
48
+ the environment is a gym environment. Only used when ``env`` is None; otherwise it is inferred from ``env``.
40
49
"""
41
50
self ._env = None
42
51
self ._raw_env = env
43
52
self ._cfg = cfg
44
53
self ._seed_api = seed_api # some env may disable `env.seed` api
45
54
self ._caller = caller
55
+
46
56
if self ._cfg is None :
47
57
self ._cfg = {}
48
58
self ._cfg = EasyDict (self ._cfg )
@@ -55,6 +65,7 @@ def __init__(self, env: gym.Env = None, cfg: dict = None, seed_api: bool = True,
55
65
if 'env_id' not in self ._cfg :
56
66
self ._cfg .env_id = None
57
67
if env is not None :
68
+ self ._is_gymnasium = isinstance (env , gymnasium .Env )
58
69
self ._env = env
59
70
self ._wrap_env (caller )
60
71
self ._observation_space = self ._env .observation_space
@@ -66,6 +77,7 @@ def __init__(self, env: gym.Env = None, cfg: dict = None, seed_api: bool = True,
66
77
self ._init_flag = True
67
78
else :
68
79
assert 'env_id' in self ._cfg
80
+ self ._is_gymnasium = is_gymnasium
69
81
self ._init_flag = False
70
82
self ._observation_space = None
71
83
self ._action_space = None
@@ -82,7 +94,8 @@ def reset(self) -> np.ndarray:
82
94
- obs (:obj:`Dict`): The new observation after reset.
83
95
"""
84
96
if not self ._init_flag :
85
- self ._env = gym .make (self ._cfg .env_id )
97
+ gym_proxy = gymnasium if self ._is_gymnasium else gym
98
+ self ._env = gym_proxy .make (self ._cfg .env_id )
86
99
self ._wrap_env (self ._caller )
87
100
self ._observation_space = self ._env .observation_space
88
101
self ._action_space = self ._env .action_space
@@ -98,29 +111,16 @@ def reset(self) -> np.ndarray:
98
111
name_prefix = 'rl-video-{}' .format (id (self ))
99
112
)
100
113
self ._replay_path = None
101
- if isinstance (self ._env , gym .Env ):
102
- if hasattr (self , '_seed' ) and hasattr (self , '_dynamic_seed' ) and self ._dynamic_seed :
103
- np_seed = 100 * np .random .randint (1 , 1000 )
104
- if self ._seed_api :
105
- self ._env .seed (self ._seed + np_seed )
106
- self ._action_space .seed (self ._seed + np_seed )
107
- elif hasattr (self , '_seed' ):
108
- if self ._seed_api :
109
- self ._env .seed (self ._seed )
110
- self ._action_space .seed (self ._seed )
111
- obs = self ._env .reset ()
112
- elif isinstance (self ._env , gymnasium .Env ):
113
- if hasattr (self , '_seed' ) and hasattr (self , '_dynamic_seed' ) and self ._dynamic_seed :
114
- np_seed = 100 * np .random .randint (1 , 1000 )
115
- self ._action_space .seed (self ._seed + np_seed )
116
- obs = self ._env .reset (seed = self ._seed + np_seed )
117
- elif hasattr (self , '_seed' ):
118
- self ._action_space .seed (self ._seed )
119
- obs = self ._env .reset (seed = self ._seed )
120
- else :
121
- obs = self ._env .reset ()
122
- else :
123
- raise RuntimeError ("not support env type: {}" .format (type (self ._env )))
114
+ if hasattr (self , '_seed' ) and hasattr (self , '_dynamic_seed' ) and self ._dynamic_seed :
115
+ np_seed = 100 * np .random .randint (1 , 1000 )
116
+ if self ._seed_api :
117
+ self ._env .seed (self ._seed + np_seed )
118
+ self ._action_space .seed (self ._seed + np_seed )
119
+ elif hasattr (self , '_seed' ):
120
+ if self ._seed_api :
121
+ self ._env .seed (self ._seed )
122
+ self ._action_space .seed (self ._seed )
123
+ obs = self ._env .reset ()
124
124
if self .observation_space .dtype == np .float32 :
125
125
obs = to_ndarray (obs , dtype = np .float32 )
126
126
else :
@@ -221,7 +221,7 @@ def random_action(self) -> np.ndarray:
221
221
random_action = self .action_space .sample ()
222
222
if isinstance (random_action , np .ndarray ):
223
223
pass
224
- elif isinstance (random_action , int ):
224
+ elif isinstance (random_action , ( int , np . int64 ) ):
225
225
random_action = to_ndarray ([random_action ], dtype = np .int64 )
226
226
elif isinstance (random_action , dict ):
227
227
random_action = to_ndarray (random_action )
@@ -241,6 +241,8 @@ def _wrap_env(self, caller: str = 'collector') -> None:
241
241
- caller (:obj:`str`): The caller of the environment, including ``collector`` or ``evaluator``. \
242
242
Different caller may need different wrappers. Default is 'collector'.
243
243
"""
244
+ if self ._is_gymnasium :
245
+ self ._env = GymToGymnasiumWrapper (self ._env )
244
246
# wrapper_cfgs: Union[str, List]
245
247
wrapper_cfgs = self ._cfg .env_wrapper
246
248
if isinstance (wrapper_cfgs , str ):
@@ -362,4 +364,4 @@ def clone(self, caller: str = 'collector') -> BaseEnv:
362
364
raw_env .__setattr__ ('spec' , spec )
363
365
except Exception :
364
366
raw_env = self ._raw_env
365
- return DingEnvWrapper (raw_env , self ._cfg , self ._seed_api , caller )
367
+ return DingEnvWrapper (raw_env , self ._cfg , self ._seed_api , caller , self . _is_gymnasium )
0 commit comments