
Commit 7f95159

feature(zym): update ppo config to support discrete action space (#809)
* feat (zym): update ppo config to support discrete action space
1 parent 35ec39e commit 7f95159
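
The commit standardizes the three Atari on-policy PPO configs on the discrete-action setup: action_space='discrete' in the model (and at the policy level for enduro), a shared encoder whose last hidden size is 512 (encoder_hidden_size_list=[32, 64, 64, 512]), and 512-wide actor/critic heads with actor_head_layer_num=0 and critic_head_layer_num=0. As a rough illustration only, the standalone PyTorch sketch below approximates the network shape these settings describe; the conv kernel sizes, strides, and head construction are assumptions, not DI-engine's actual VAC implementation.

```python
# Illustrative sketch only: an actor-critic roughly matching obs_shape=[4, 84, 84],
# encoder_hidden_size_list=[32, 64, 64, 512], 512-wide heads and discrete action logits.
# Kernel sizes/strides and head construction are assumptions, not DI-engine's model code.
import torch
import torch.nn as nn


class DiscreteActorCritic(nn.Module):

    def __init__(self, obs_shape=(4, 84, 84), action_shape=6):
        super().__init__()
        # Shared encoder: three conv layers (32/64/64 channels, classic Atari sizes
        # assumed here) followed by a 512-unit fully-connected layer.
        self.encoder = nn.Sequential(
            nn.Conv2d(obs_shape[0], 32, kernel_size=8, stride=4), nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=4, stride=2), nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, stride=1), nn.ReLU(),
            nn.Flatten(),
            nn.Linear(64 * 7 * 7, 512), nn.ReLU(),
        )
        # With *_head_layer_num=0 there are no extra hidden layers in the heads: the actor
        # maps the 512-d feature straight to action logits, the critic to a scalar value.
        self.actor_head = nn.Linear(512, action_shape)
        self.critic_head = nn.Linear(512, 1)

    def forward(self, obs):
        feat = self.encoder(obs)
        return self.actor_head(feat), self.critic_head(feat)


if __name__ == "__main__":
    model = DiscreteActorCritic()
    logits, value = model(torch.randn(2, 4, 84, 84))
    print(logits.shape, value.shape)  # torch.Size([2, 6]) torch.Size([2, 1])
```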

File tree

3 files changed: +54 -43 lines changed

dizoo/atari/config/serial/enduro/enduro_onppo_config.py (+27 -20)

@@ -3,7 +3,7 @@
 enduro_onppo_config = dict(
     exp_name='enduro_onppo_seed0',
     env=dict(
-        collector_env_num=64,
+        collector_env_num=8,
         evaluator_env_num=8,
         n_evaluator_episode=8,
         stop_value=10000000000,
@@ -14,38 +14,45 @@
     ),
     policy=dict(
         cuda=True,
+        recompute_adv=True,
+        action_space='discrete',
         model=dict(
             obs_shape=[4, 84, 84],
             action_shape=9,
-            encoder_hidden_size_list=[32, 64, 64, 128],
-            actor_head_hidden_size=128,
-            critic_head_hidden_size=128,
-            critic_head_layer_num=2,
+            action_space='discrete',
+            encoder_hidden_size_list=[32, 64, 64, 512],
+            actor_head_layer_num=0,
+            critic_head_layer_num=0,
+            actor_head_hidden_size=512,
+            critic_head_hidden_size=512,
         ),
         learn=dict(
-            update_per_collect=24,
-            batch_size=128,
-            # (bool) Whether to normalize advantage. Default to False.
-            adv_norm=False,
-            learning_rate=0.0001,
-            # (float) loss weight of the value network, the weight of policy network is set to 1
-            value_weight=1.0,
-            # (float) loss weight of the entropy regularization, the weight of policy network is set to 1
-            entropy_weight=0.001,  # [0.1, 0.01, 0.0]
-            clip_ratio=0.1
+            lr_scheduler=dict(epoch_num=5200, min_lr_lambda=0),
+            epoch_per_collect=4,
+            batch_size=256,
+            learning_rate=2.5e-4,
+            value_weight=0.5,
+            entropy_weight=0.01,
+            clip_ratio=0.1,
+            adv_norm=True,
+            value_norm=True,
+            # for onppo, when we recompute adv, we need the key done in data to split traj, so we must
+            # use ignore_done=False here,
+            # but when we add key traj_flag in data as the backup for key done, we could choose to use ignore_done=True
+            # for halfcheetah, the length=1000
+            ignore_done=False,
+            grad_clip_type='clip_norm',
+            grad_clip_value=0.5,
         ),
         collect=dict(
             # (int) collect n_sample data, train model n_iteration times
             n_sample=1024,
+            unroll_len=1,
             # (float) the trade-off factor lambda to balance 1step td and mc
             gae_lambda=0.95,
             discount_factor=0.99,
         ),
-        eval=dict(evaluator=dict(eval_freq=1000, )),
-        other=dict(replay_buffer=dict(
-            replay_buffer_size=10000,
-            max_use=3,
-        ), ),
+        eval=dict(evaluator=dict(eval_freq=5000, )),
     ),
 )
 main_config = EasyDict(enduro_onppo_config)
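
The enduro config also gains lr_scheduler=dict(epoch_num=5200, min_lr_lambda=0) alongside learning_rate=2.5e-4. Below is a minimal sketch of what that plausibly corresponds to, assuming it maps to a linear LambdaLR-style decay of the learning-rate multiplier from 1.0 down to min_lr_lambda over epoch_num training epochs; how DI-engine actually consumes these keys is not shown in this diff.

```python
# Sketch, assuming lr_scheduler=dict(epoch_num=5200, min_lr_lambda=0) means a linear
# decay of the learning-rate multiplier from 1.0 to min_lr_lambda over epoch_num epochs.
import torch
from torch.optim.lr_scheduler import LambdaLR

epoch_num, min_lr_lambda = 5200, 0.0

param = torch.nn.Parameter(torch.zeros(1))
optimizer = torch.optim.Adam([param], lr=2.5e-4)
scheduler = LambdaLR(
    optimizer,
    lr_lambda=lambda epoch: max(min_lr_lambda, 1.0 - epoch / epoch_num),
)

for epoch in range(3):
    optimizer.step()
    scheduler.step()
print(optimizer.param_groups[0]['lr'])  # slightly below 2.5e-4 after 3 epochs
```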

dizoo/atari/config/serial/qbert/qbert_onppo_config.py (+14 -12)

@@ -1,9 +1,9 @@
 from easydict import EasyDict
 
 qbert_onppo_config = dict(
-    exp_name='enduro_onppo_seed0',
+    exp_name='qbert_onppo_seed0',
     env=dict(
-        collector_env_num=16,
+        collector_env_num=8,
         evaluator_env_num=8,
         n_evaluator_episode=8,
         stop_value=int(1e10),
@@ -19,18 +19,20 @@
             obs_shape=[4, 84, 84],
             action_shape=6,
             action_space='discrete',
-            encoder_hidden_size_list=[64, 64, 128],
-            actor_head_hidden_size=128,
-            critic_head_hidden_size=128,
+            encoder_hidden_size_list=[32, 64, 64, 512],
+            actor_head_layer_num=0,
+            critic_head_layer_num=0,
+            actor_head_hidden_size=512,
+            critic_head_hidden_size=512,
         ),
         learn=dict(
-            epoch_per_collect=10,
-            update_per_collect=1,
-            batch_size=320,
-            learning_rate=3e-4,
+            lr_scheduler=dict(epoch_num=5200, min_lr_lambda=0),
+            epoch_per_collect=4,
+            batch_size=256,
+            learning_rate=2.5e-4,
             value_weight=0.5,
-            entropy_weight=0.001,
-            clip_ratio=0.2,
+            entropy_weight=0.01,
+            clip_ratio=0.1,
             adv_norm=True,
             value_norm=True,
             # for onppo, when we recompute adv, we need the key done in data to split traj, so we must
@@ -42,7 +44,7 @@
             grad_clip_value=0.5,
         ),
         collect=dict(
-            n_sample=3200,
+            n_sample=1024,
             unroll_len=1,
             discount_factor=0.99,
             gae_lambda=0.95,
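
The qbert config now collects n_sample=1024 transitions per cycle and trains with batch_size=256 for epoch_per_collect=4 epochs. Assuming each epoch simply iterates over the collected data in minibatches of batch_size, that works out to about 16 gradient updates per collection, as the quick check below shows.

```python
# Back-of-the-envelope check of the new data flow per collection cycle
# (assumes each epoch iterates over the collected samples in minibatches of batch_size).
n_sample = 1024          # transitions collected per cycle
batch_size = 256         # minibatch size during learning
epoch_per_collect = 4    # passes over the collected data

minibatches_per_epoch = n_sample // batch_size                    # 4
updates_per_collect = epoch_per_collect * minibatches_per_epoch   # 16
print(updates_per_collect)
```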

dizoo/atari/config/serial/spaceinvaders/spaceinvaders_onppo_config.py (+13 -11)

@@ -4,7 +4,7 @@
 spaceinvaders_ppo_config = dict(
     exp_name='spaceinvaders_onppo_seed0',
     env=dict(
-        collector_env_num=16,
+        collector_env_num=8,
         evaluator_env_num=8,
         n_evaluator_episode=8,
         stop_value=int(1e10),
@@ -21,18 +21,20 @@
             obs_shape=[4, 84, 84],
             action_shape=6,
             action_space='discrete',
-            encoder_hidden_size_list=[64, 64, 128],
-            actor_head_hidden_size=128,
-            critic_head_hidden_size=128,
+            encoder_hidden_size_list=[32, 64, 64, 512],
+            actor_head_layer_num=0,
+            critic_head_layer_num=0,
+            actor_head_hidden_size=512,
+            critic_head_hidden_size=512,
         ),
         learn=dict(
-            epoch_per_collect=10,
-            update_per_collect=1,
-            batch_size=320,
-            learning_rate=3e-4,
+            lr_scheduler=dict(epoch_num=5200, min_lr_lambda=0),
+            epoch_per_collect=4,
+            batch_size=256,
+            learning_rate=2.5e-4,
             value_weight=0.5,
-            entropy_weight=0.001,
-            clip_ratio=0.2,
+            entropy_weight=0.01,
+            clip_ratio=0.1,
             adv_norm=True,
             value_norm=True,
             # for onppo, when we recompute adv, we need the key done in data to split traj, so we must
@@ -44,7 +46,7 @@
             grad_clip_value=0.5,
         ),
         collect=dict(
-            n_sample=3200,
+            n_sample=1024,
             unroll_len=1,
             discount_factor=0.99,
             gae_lambda=0.95,
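
None of the files shown here includes launch code, so the snippet below is only a hypothetical usage sketch: the serial_pipeline_onpolicy entry point and the create_config contents are assumptions based on how other DI-engine on-policy Atari configs are typically run, not part of this commit.

```python
# Hypothetical launch snippet: the entry point and create_config layout are assumptions
# about DI-engine conventions; they are not included in this commit's diff.
from easydict import EasyDict
from ding.entry import serial_pipeline_onpolicy

# Assumes the qbert config module exports main_config (as the enduro file does).
from dizoo.atari.config.serial.qbert.qbert_onppo_config import main_config

create_config = EasyDict(dict(
    env=dict(type='atari', import_names=['dizoo.atari.envs.atari_env']),
    env_manager=dict(type='subprocess'),
    policy=dict(type='ppo'),
))

if __name__ == '__main__':
    serial_pipeline_onpolicy((main_config, create_config), seed=0)
```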
