Skip to content

Commit 73ff16f

Browse files
authored
feature(wrh): add taxi env latest version and dqn config (#807)
* update taxi env
1 parent 91bc342 commit 73ff16f

File tree

4 files changed

+36
-26
lines changed

4 files changed

+36
-26
lines changed

README.md

+1-1
Original file line numberDiff line numberDiff line change
@@ -324,7 +324,7 @@ P.S: The `.py` file in `Runnable Demo` can be found in `dizoo`
324324
| 37 | [tabmwp](https://promptpg.github.io/explore.html) | ![discrete](https://img.shields.io/badge/-discrete-brightgreen) | ![original](./dizoo/tabmwp/tabmwp.jpeg) | [dizoo link](https://github.com/opendilab/DI-engine/tree/main/dizoo/tabmwp) <br> env tutorial <br> 环境指南 |
325325
| 38 | [frozen_lake](https://gymnasium.farama.org/environments/toy_text/frozen_lake) | ![discrete](https://img.shields.io/badge/-discrete-brightgreen) | ![original](./dizoo/frozen_lake/FrozenLake.gif) | [dizoo link](https://github.com/opendilab/DI-engine/tree/main/dizoo/frozen_lake) <br> env tutorial <br> 环境指南 |
326326
| 39 | [ising_model](https://github.com/mlii/mfrl/tree/master/examples/ising_model) | ![discrete](https://img.shields.io/badge/-discrete-brightgreen) ![marl](https://img.shields.io/badge/-MARL-yellow) | ![original](./dizoo/ising_env/ising_env.gif) | [dizoo link](https://github.com/opendilab/DI-engine/tree/main/dizoo/ising_env) <br> env tutorial <br> [环境指南](https://di-engine-docs.readthedocs.io/zh_CN/latest/13_envs/ising_model_zh.html) |
327-
| 40 | [taxi](https://www.gymlibrary.dev/environments/toy_text/taxi/) | ![discrete](https://img.shields.io/badge/-discrete-brightgreen) | ![original](./dizoo/taxi/Taxi-v3_episode_0.gif) | dizoo link <br> env tutorial <br> 环境指南 |
327+
| 40 | [taxi](https://www.gymlibrary.dev/environments/toy_text/taxi/) | ![discrete](https://img.shields.io/badge/-discrete-brightgreen) | ![original](./dizoo/taxi/Taxi-v3_episode_0.gif) | [dizoo link](https://github.com/opendilab/DI-engine/tree/main/dizoo/taxi/envs) <br> [env tutorial](https://di-engine-docs.readthedocs.io/en/latest/13_envs/taxi.html) <br> [环境指南](https://di-engine-docs.readthedocs.io/zh-cn/latest/13_envs/taxi_zh.html) |
328328

329329

330330

dizoo/taxi/config/taxi_dqn_config.py

+26-20
Original file line numberDiff line numberDiff line change
@@ -1,39 +1,45 @@
11
from easydict import EasyDict
22

33
taxi_dqn_config = dict(
4-
exp_name='taxi_seed0',
4+
exp_name='taxi_dqn_seed0',
55
env=dict(
66
collector_env_num=8,
7-
evaluator_env_num=8,
8-
n_evaluator_episode=10,
9-
max_episode_steps=300,
10-
env_id="Taxi-v3"
7+
evaluator_env_num=8,
8+
n_evaluator_episode=8,
9+
stop_value=20,
10+
max_episode_steps=60,
11+
env_id="Taxi-v3"
1112
),
1213
policy=dict(
1314
cuda=True,
14-
load_path="./taxi_dqn_seed0/ckpt/ckpt_best.pth.tar",
1515
model=dict(
16-
obs_shape=4,
16+
obs_shape=34,
1717
action_shape=6,
18-
encoder_hidden_size_list=[256, 128, 64]
18+
encoder_hidden_size_list=[128, 128]
1919
),
20+
random_collect_size=5000,
2021
nstep=3,
21-
discount_factor=0.98,
22+
discount_factor=0.99,
2223
learn=dict(
23-
update_per_collect=5,
24-
batch_size=128,
25-
learning_rate=0.001,
24+
update_per_collect=10,
25+
batch_size=64,
26+
learning_rate=0.0001,
27+
learner=dict(
28+
hook=dict(
29+
log_show_after_iter=1000,
30+
)
31+
),
2632
),
27-
collect=dict(n_sample=10),
28-
eval=dict(evaluator=dict(eval_freq=5, )),
33+
collect=dict(n_sample=32),
34+
eval=dict(evaluator=dict(eval_freq=1000, )),
2935
other=dict(
3036
eps=dict(
3137
type="linear",
32-
start=0.8,
33-
end=0.1,
34-
decay=10000
35-
),
36-
replay_buffer=dict(replay_buffer_size=20000,),
38+
start=1,
39+
end=0.05,
40+
decay=3000000
41+
),
42+
replay_buffer=dict(replay_buffer_size=100000,),
3743
),
3844
)
3945
)
@@ -55,4 +61,4 @@
5561

5662
if __name__ == "__main__":
5763
from ding.entry import serial_pipeline
58-
serial_pipeline((main_config, create_config), max_env_step=5000, seed=0)
64+
serial_pipeline((main_config, create_config), max_env_step=3000000, seed=0)

dizoo/taxi/envs/taxi_env.py

+7-3
Original file line numberDiff line numberDiff line change
@@ -93,8 +93,8 @@ def step(self, action: np.ndarray) -> BaseEnvTimestep:
9393
def enable_save_replay(self, replay_path: Optional[str] = None) -> None:
9494
if replay_path is None:
9595
replay_path = './video'
96-
if not os.path.exists(replay_path):
97-
os.makedirs(replay_path)
96+
if not os.path.exists(replay_path):
97+
os.makedirs(replay_path)
9898
self._replay_path = replay_path
9999
self._save_replay = True
100100
self._save_replay_count = 0
@@ -118,7 +118,11 @@ def random_action(self) -> np.ndarray:
118118
#todo encode the state into a vector
119119
def _encode_taxi(self, obs: np.ndarray) -> np.ndarray:
120120
taxi_row, taxi_col, passenger_location, destination = self._env.unwrapped.decode(obs)
121-
return to_ndarray([taxi_row, taxi_col, passenger_location, destination])
121+
encoded_obs = np.zeros(34)
122+
encoded_obs[5 * taxi_row + taxi_col] = 1
123+
encoded_obs[25 + passenger_location] = 1
124+
encoded_obs[30 + destination] = 1
125+
return to_ndarray(encoded_obs)
122126

123127
@property
124128
def observation_space(self) -> Space:

dizoo/taxi/envs/test_taxi_env.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ def test_naive(self):
1616
env.seed(314, dynamic_seed=False)
1717
assert env._seed == 314
1818
obs = env.reset()
19-
assert obs.shape == (4, )
19+
assert obs.shape == (34, )
2020
for _ in range(5):
2121
env.reset()
2222
np.random.seed(314)
@@ -32,7 +32,7 @@ def test_naive(self):
3232
print(f"Your timestep in wrapped mode is: {timestep}")
3333
assert isinstance(timestep.obs, np.ndarray)
3434
assert isinstance(timestep.done, bool)
35-
assert timestep.obs.shape == (4, )
35+
assert timestep.obs.shape == (34, )
3636
assert timestep.reward.shape == (1, )
3737
assert timestep.reward >= env.reward_space.low
3838
assert timestep.reward <= env.reward_space.high

0 commit comments

Comments
 (0)