Skip to content

Commit d41ece2

Browse files
committed
format my code
1 parent 0576208 commit d41ece2

File tree

5 files changed

+57
-59
lines changed

5 files changed

+57
-59
lines changed

ding/example/dqn_frozen_lake.py

+4-6
Original file line numberDiff line numberDiff line change
@@ -12,19 +12,17 @@
1212
from dizoo.frozen_lake.config.frozen_lake_dqn_config import main_config, create_config
1313
from dizoo.frozen_lake.envs import FrozenLakeEnv
1414

15+
1516
def main():
1617
logging.getLogger().setLevel(logging.INFO)
17-
main_config.exp_name = 'cartpole_dqn_nstep'
1818
main_config.policy.nstep = 5
1919
cfg = compile_config(main_config, create_cfg=create_config, auto=True)
2020
with task.start(async_mode=False, ctx=OnlineRLContext()):
2121
collector_env = BaseEnvManagerV2(
22-
env_fn=[lambda: FrozenLakeEnv(cfg=cfg.env) for _ in range(cfg.env.collector_env_num)],
23-
cfg=cfg.env.manager
22+
env_fn=[lambda: FrozenLakeEnv(cfg=cfg.env) for _ in range(cfg.env.collector_env_num)], cfg=cfg.env.manager
2423
)
2524
evaluator_env = BaseEnvManagerV2(
26-
env_fn=[lambda: FrozenLakeEnv(cfg=cfg.env) for _ in range(cfg.env.evaluator_env_num)],
27-
cfg=cfg.env.manager
25+
env_fn=[lambda: FrozenLakeEnv(cfg=cfg.env) for _ in range(cfg.env.evaluator_env_num)], cfg=cfg.env.manager
2826
)
2927
set_pkg_seed(cfg.seed, use_cuda=cfg.policy.cuda)
3028

@@ -44,4 +42,4 @@ def main():
4442

4543

4644
if __name__ == "__main__":
47-
main()
45+
main()

dizoo/frozen_lake/config/__init__.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
from .frozen_lake_dqn_config import main_config, create_config
1+
from .frozen_lake_dqn_config import main_config, create_config

dizoo/frozen_lake/config/frozen_lake_dqn_config.py

+10-11
Original file line numberDiff line numberDiff line change
@@ -6,30 +6,29 @@
66
collector_env_num=8,
77
evaluator_env_num=5,
88
n_evaluator_episode=10,
9-
env_id = 'FrozenLake-v1',
10-
desc = None,
11-
map_name = "4x4",
12-
is_slippery = False,
13-
save_replay_gif = False,
9+
env_id='FrozenLake-v1',
10+
desc=None,
11+
map_name="4x4",
12+
is_slippery=False,
13+
save_replay_gif=False,
1414
),
15-
16-
policy = dict(
15+
policy=dict(
1716
cuda=True,
1817
load_path='frozen_lake_seed0/ckpt/ckpt_best.pth.tar',
19-
model = dict(
18+
model=dict(
2019
obs_shape=16,
2120
action_shape=4,
2221
encoder_hidden_size_list=[128, 128, 64],
2322
dueling=True,
2423
),
25-
nstep = 3,
24+
nstep=3,
2625
discount_factor=0.97,
2726
learn=dict(
2827
update_per_collect=5,
2928
batch_size=256,
3029
learning_rate=0.001,
3130
),
32-
collect = dict(n_sample=10),
31+
collect=dict(n_sample=10),
3332
eval=dict(evaluator=dict(eval_freq=40, )),
3433
other=dict(
3534
eps=dict(
@@ -62,4 +61,4 @@
6261
if __name__ == "__main__":
6362
# or you can enter `ding -m serial -c frozen_lake_dqn_config.py -s 0`
6463
from ding.entry import serial_pipeline
65-
serial_pipeline((main_config, create_config), max_env_step=5000,seed=0)
64+
serial_pipeline((main_config, create_config), max_env_step=5000, seed=0)

dizoo/frozen_lake/envs/frozen_lake_env.py

+30-32
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Any, Dict,List, Optional
1+
from typing import Any, Dict, List, Optional
22
import imageio
33
import os
44
import gymnasium as gymn
@@ -7,10 +7,12 @@
77
from ding.torch_utils import to_ndarray
88
from ding.utils import ENV_REGISTRY
99

10+
1011
@ENV_REGISTRY.register('frozen_lake')
1112
class FrozenLakeEnv(BaseEnv):
12-
def __init__(self,cfg)->None:
13-
self._cfg=cfg
13+
14+
def __init__(self, cfg) -> None:
15+
self._cfg = cfg
1416
assert self._cfg.env_id == "FrozenLake-v1", "yout name is not FrozernLake_v1"
1517
self._init_flag = False
1618
self._save_replay_bool = False
@@ -19,31 +21,33 @@ def __init__(self,cfg)->None:
1921
self._frames = []
2022
self._replay_path = False
2123

22-
def reset(self)-> np.ndarray:
24+
def reset(self) -> np.ndarray:
2325
if not self._init_flag:
24-
if not self._cfg.desc :#specify maps non-preloaded maps
25-
self._env = gymn.make(self._cfg.env_id,
26-
desc=self._cfg.desc,
27-
map_name=self._cfg.map_name,
28-
is_slippery=self._cfg.is_slippery,
29-
render_mode="rgb_array")
26+
if not self._cfg.desc: #specify maps non-preloaded maps
27+
self._env = gymn.make(
28+
self._cfg.env_id,
29+
desc=self._cfg.desc,
30+
map_name=self._cfg.map_name,
31+
is_slippery=self._cfg.is_slippery,
32+
render_mode="rgb_array"
33+
)
3034
self._observation_space = self._env.observation_space
3135
self._action_space = self._env.action_space
3236
self._reward_space = gymn.spaces.Box(
33-
low=self._env.reward_range[0], high=self._env.reward_range[1], shape=(1, ), dtype=np.float32
34-
)
37+
low=self._env.reward_range[0], high=self._env.reward_range[1], shape=(1, ), dtype=np.float32
38+
)
3539
self._init_flag = True
3640
self._eval_episode_return = 0
3741
if hasattr(self, '_seed') and hasattr(self, '_dynamic_seed') and self._dynamic_seed:
3842
np_seed = 100 * np.random.randint(1, 1000)
39-
self._env_seed=self._seed + np_seed
43+
self._env_seed = self._seed + np_seed
4044
elif hasattr(self, '_seed'):
41-
self._env_seed=self._seed
45+
self._env_seed = self._seed
4246
if hasattr(self, '_seed'):
43-
obs,info = self._env.reset(seed=self._env_seed)
47+
obs, info = self._env.reset(seed=self._env_seed)
4448
else:
45-
obs,info = self._env.reset()
46-
obs = self.onehot_encode(obs)
49+
obs, info = self._env.reset()
50+
obs = np.eye(16, dtype=np.float32)[obs - 1]
4751
return obs
4852

4953
def close(self) -> None:
@@ -57,30 +61,30 @@ def seed(self, seed: int, dynamic_seed: bool = True) -> None:
5761
np.random.seed(self._seed)
5862

5963
def step(self, action: Dict) -> BaseEnvTimestep:
60-
obs, rew, terminated, truncated,info = self._env.step(action[0])
64+
obs, rew, terminated, truncated, info = self._env.step(action[0])
6165
self._eval_episode_return += rew
62-
obs = self.onehot_encode(obs)
66+
obs = np.eye(16, dtype=np.float32)[obs - 1]
6367
rew = to_ndarray([rew])
6468
if self._save_replay_bool:
65-
picture=self._env.render()
69+
picture = self._env.render()
6670
self._frames.append(picture)
6771
if terminated or truncated:
6872
done = True
69-
else :
73+
else:
7074
done = False
7175
if done:
7276
info['eval_episode_return'] = self._eval_episode_return
7377
if self._save_replay_bool:
74-
assert self._replay_path is not None,"your should have a path"
78+
assert self._replay_path is not None, "your should have a path"
7579
path = os.path.join(
76-
self._replay_path, '{}_episode_{}.gif'.format(self._cfg.env_id, self._save_replay_count)
77-
)
78-
self.frames_to_gif(self._frames,path)
80+
self._replay_path, '{}_episode_{}.gif'.format(self._cfg.env_id, self._save_replay_count)
81+
)
82+
self.frames_to_gif(self._frames, path)
7983
self._frames = []
8084
self._save_replay_count += 1
8185
rew = rew.astype(np.float32)
8286
return BaseEnvTimestep(obs, rew, done, info)
83-
87+
8488
def random_action(self) -> Dict:
8589
raw_action = self._env.action_space.sample()
8690
my_type = type(self._env.action_space)
@@ -109,7 +113,6 @@ def enable_save_replay(self, replay_path: Optional[str] = None) -> None:
109113
self._save_replay_count = 0
110114
self._frames = []
111115

112-
113116
@staticmethod
114117
def frames_to_gif(frames: List[imageio.core.util.Array], gif_path: str, duration: float = 0.1) -> None:
115118
"""
@@ -138,9 +141,4 @@ def frames_to_gif(frames: List[imageio.core.util.Array], gif_path: str, duration
138141
# Clean up temporary image files
139142
for temp_image_file in temp_image_files:
140143
os.remove(temp_image_file)
141-
142144
print(f"GIF saved as {gif_path}")
143-
144-
def onehot_encode(self, x):
145-
onehot = np.eye(16, dtype=np.float32)[x - 1]
146-
return onehot

dizoo/frozen_lake/envs/test_frozen_lake_env.py

+12-9
Original file line numberDiff line numberDiff line change
@@ -3,23 +3,26 @@
33
from dizoo.frozen_lake.envs import FrozenLakeEnv
44
from easydict import EasyDict
55

6+
67
@pytest.mark.envtest
78
class TestGymHybridEnv:
9+
810
def test_my_lake(self):
911
env = FrozenLakeEnv(
10-
EasyDict(
11-
{
12-
'env_id': 'FrozenLake-v1',
13-
'desc': None,
14-
'map_name': "4x4",
15-
'is_slippery': False,
16-
}
17-
))
12+
EasyDict({
13+
'env_id': 'FrozenLake-v1',
14+
'desc': None,
15+
'map_name': "4x4",
16+
'is_slippery': False,
17+
})
18+
)
1819
for _ in range(5):
1920
env.seed(314, dynamic_seed=False)
2021
assert env._seed == 314
2122
obs = env.reset()
22-
assert obs.shape == (16,), "Considering the one-hot encoding format, your observation should have a dimensionality of 16."
23+
assert obs.shape == (
24+
16,
25+
), "Considering the one-hot encoding format, your observation should have a dimensionality of 16."
2326
for i in range(10):
2427
env.enable_save_replay("./video")
2528
# Both ``env.random_action()``, and utilizing ``np.random`` as well as action space,

0 commit comments

Comments
 (0)