1
- from typing import Any , Dict ,List , Optional
1
+ from typing import Any , Dict , List , Optional
2
2
import imageio
3
3
import os
4
4
import gymnasium as gymn
7
7
from ding .torch_utils import to_ndarray
8
8
from ding .utils import ENV_REGISTRY
9
9
10
+
10
11
@ENV_REGISTRY .register ('frozen_lake' )
11
12
class FrozenLakeEnv (BaseEnv ):
12
- def __init__ (self ,cfg )-> None :
13
- self ._cfg = cfg
13
+
14
+ def __init__ (self , cfg ) -> None :
15
+ self ._cfg = cfg
14
16
assert self ._cfg .env_id == "FrozenLake-v1" , "yout name is not FrozernLake_v1"
15
17
self ._init_flag = False
16
18
self ._save_replay_bool = False
@@ -19,31 +21,33 @@ def __init__(self,cfg)->None:
19
21
self ._frames = []
20
22
self ._replay_path = False
21
23
22
- def reset (self )-> np .ndarray :
24
+ def reset (self ) -> np .ndarray :
23
25
if not self ._init_flag :
24
- if not self ._cfg .desc :#specify maps non-preloaded maps
25
- self ._env = gymn .make (self ._cfg .env_id ,
26
- desc = self ._cfg .desc ,
27
- map_name = self ._cfg .map_name ,
28
- is_slippery = self ._cfg .is_slippery ,
29
- render_mode = "rgb_array" )
26
+ if not self ._cfg .desc : #specify maps non-preloaded maps
27
+ self ._env = gymn .make (
28
+ self ._cfg .env_id ,
29
+ desc = self ._cfg .desc ,
30
+ map_name = self ._cfg .map_name ,
31
+ is_slippery = self ._cfg .is_slippery ,
32
+ render_mode = "rgb_array"
33
+ )
30
34
self ._observation_space = self ._env .observation_space
31
35
self ._action_space = self ._env .action_space
32
36
self ._reward_space = gymn .spaces .Box (
33
- low = self ._env .reward_range [0 ], high = self ._env .reward_range [1 ], shape = (1 , ), dtype = np .float32
34
- )
37
+ low = self ._env .reward_range [0 ], high = self ._env .reward_range [1 ], shape = (1 , ), dtype = np .float32
38
+ )
35
39
self ._init_flag = True
36
40
self ._eval_episode_return = 0
37
41
if hasattr (self , '_seed' ) and hasattr (self , '_dynamic_seed' ) and self ._dynamic_seed :
38
42
np_seed = 100 * np .random .randint (1 , 1000 )
39
- self ._env_seed = self ._seed + np_seed
43
+ self ._env_seed = self ._seed + np_seed
40
44
elif hasattr (self , '_seed' ):
41
- self ._env_seed = self ._seed
45
+ self ._env_seed = self ._seed
42
46
if hasattr (self , '_seed' ):
43
- obs ,info = self ._env .reset (seed = self ._env_seed )
47
+ obs , info = self ._env .reset (seed = self ._env_seed )
44
48
else :
45
- obs ,info = self ._env .reset ()
46
- obs = self . onehot_encode ( obs )
49
+ obs , info = self ._env .reset ()
50
+ obs = np . eye ( 16 , dtype = np . float32 )[ obs - 1 ]
47
51
return obs
48
52
49
53
def close (self ) -> None :
@@ -57,30 +61,30 @@ def seed(self, seed: int, dynamic_seed: bool = True) -> None:
57
61
np .random .seed (self ._seed )
58
62
59
63
def step (self , action : Dict ) -> BaseEnvTimestep :
60
- obs , rew , terminated , truncated ,info = self ._env .step (action [0 ])
64
+ obs , rew , terminated , truncated , info = self ._env .step (action [0 ])
61
65
self ._eval_episode_return += rew
62
- obs = self . onehot_encode ( obs )
66
+ obs = np . eye ( 16 , dtype = np . float32 )[ obs - 1 ]
63
67
rew = to_ndarray ([rew ])
64
68
if self ._save_replay_bool :
65
- picture = self ._env .render ()
69
+ picture = self ._env .render ()
66
70
self ._frames .append (picture )
67
71
if terminated or truncated :
68
72
done = True
69
- else :
73
+ else :
70
74
done = False
71
75
if done :
72
76
info ['eval_episode_return' ] = self ._eval_episode_return
73
77
if self ._save_replay_bool :
74
- assert self ._replay_path is not None ,"your should have a path"
78
+ assert self ._replay_path is not None , "your should have a path"
75
79
path = os .path .join (
76
- self ._replay_path , '{}_episode_{}.gif' .format (self ._cfg .env_id , self ._save_replay_count )
77
- )
78
- self .frames_to_gif (self ._frames ,path )
80
+ self ._replay_path , '{}_episode_{}.gif' .format (self ._cfg .env_id , self ._save_replay_count )
81
+ )
82
+ self .frames_to_gif (self ._frames , path )
79
83
self ._frames = []
80
84
self ._save_replay_count += 1
81
85
rew = rew .astype (np .float32 )
82
86
return BaseEnvTimestep (obs , rew , done , info )
83
-
87
+
84
88
def random_action (self ) -> Dict :
85
89
raw_action = self ._env .action_space .sample ()
86
90
my_type = type (self ._env .action_space )
@@ -109,7 +113,6 @@ def enable_save_replay(self, replay_path: Optional[str] = None) -> None:
109
113
self ._save_replay_count = 0
110
114
self ._frames = []
111
115
112
-
113
116
@staticmethod
114
117
def frames_to_gif (frames : List [imageio .core .util .Array ], gif_path : str , duration : float = 0.1 ) -> None :
115
118
"""
@@ -138,9 +141,4 @@ def frames_to_gif(frames: List[imageio.core.util.Array], gif_path: str, duration
138
141
# Clean up temporary image files
139
142
for temp_image_file in temp_image_files :
140
143
os .remove (temp_image_file )
141
-
142
144
print (f"GIF saved as { gif_path } " )
143
-
144
- def onehot_encode (self , x ):
145
- onehot = np .eye (16 , dtype = np .float32 )[x - 1 ]
146
- return onehot
0 commit comments