
Commit f5fed7c

polish(zym): optimize ppo continuous act (#801)
* fix (zym): update hidden size of head in VAC module
* feat (zym): update ppo config to support continuous action space
1 parent d919fa5 commit f5fed7c

File tree: 6 files changed, +100 −31 lines

ding/model/template/vac.py

+17 −12

@@ -54,12 +54,12 @@ def __init__(
             ``ReparameterizationHead``, and hybrid heads.
         - share_encoder (:obj:`bool`): Whether to share observation encoders between actor and decoder.
         - encoder_hidden_size_list (:obj:`SequenceType`): Collection of ``hidden_size`` to pass to ``Encoder``, \
-            the last element must match ``head_hidden_size``.
+            the last element is used as the input size of ``actor_head`` and ``critic_head``.
         - actor_head_hidden_size (:obj:`Optional[int]`): The ``hidden_size`` of ``actor_head`` network, defaults \
-            to 64, it must match the last element of ``encoder_hidden_size_list``.
+            to 64, it is the hidden size of the last layer of the ``actor_head`` network.
         - actor_head_layer_num (:obj:`int`): The num of layers used in the ``actor_head`` network to compute action.
         - critic_head_hidden_size (:obj:`Optional[int]`): The ``hidden_size`` of ``critic_head`` network, defaults \
-            to 64, it must match the last element of ``encoder_hidden_size_list``.
+            to 64, it is the hidden size of the last layer of the ``critic_head`` network.
         - critic_head_layer_num (:obj:`int`): The num of layers used in the ``critic_head`` network.
         - activation (:obj:`Optional[nn.Module]`): The type of activation function in networks \
             if ``None`` then default set it to ``nn.ReLU()``.
@@ -108,15 +108,13 @@ def new_encoder(outsize, activation):
                 )

         if self.share_encoder:
-            assert actor_head_hidden_size == critic_head_hidden_size, \
-                "actor and critic network head should have same size."
             if encoder:
                 if isinstance(encoder, torch.nn.Module):
                     self.encoder = encoder
                 else:
                     raise ValueError("illegal encoder instance.")
             else:
-                self.encoder = new_encoder(actor_head_hidden_size, activation)
+                self.encoder = new_encoder(encoder_hidden_size_list[-1], activation)
         else:
             if encoder:
                 if isinstance(encoder, torch.nn.Module):
@@ -125,25 +123,31 @@ def new_encoder(outsize, activation):
                 else:
                     raise ValueError("illegal encoder instance.")
             else:
-                self.actor_encoder = new_encoder(actor_head_hidden_size, activation)
-                self.critic_encoder = new_encoder(critic_head_hidden_size, activation)
+                self.actor_encoder = new_encoder(encoder_hidden_size_list[-1], activation)
+                self.critic_encoder = new_encoder(encoder_hidden_size_list[-1], activation)

         # Head Type
         self.critic_head = RegressionHead(
-            critic_head_hidden_size, 1, critic_head_layer_num, activation=activation, norm_type=norm_type
+            encoder_hidden_size_list[-1],
+            1,
+            critic_head_layer_num,
+            activation=activation,
+            norm_type=norm_type,
+            hidden_size=critic_head_hidden_size
         )
         self.action_space = action_space
         assert self.action_space in ['discrete', 'continuous', 'hybrid'], self.action_space
         if self.action_space == 'continuous':
             self.multi_head = False
             self.actor_head = ReparameterizationHead(
-                actor_head_hidden_size,
+                encoder_hidden_size_list[-1],
                 action_shape,
                 actor_head_layer_num,
                 sigma_type=sigma_type,
                 activation=activation,
                 norm_type=norm_type,
-                bound_type=bound_type
+                bound_type=bound_type,
+                hidden_size=actor_head_hidden_size,
             )
         elif self.action_space == 'discrete':
             actor_head_cls = DiscreteHead
@@ -172,14 +176,15 @@ def new_encoder(outsize, activation):
            action_shape.action_args_shape = squeeze(action_shape.action_args_shape)
            action_shape.action_type_shape = squeeze(action_shape.action_type_shape)
            actor_action_args = ReparameterizationHead(
-               actor_head_hidden_size,
+               encoder_hidden_size_list[-1],
                action_shape.action_args_shape,
                actor_head_layer_num,
                sigma_type=sigma_type,
                fixed_sigma_value=fixed_sigma_value,
                activation=activation,
                norm_type=norm_type,
                bound_type=bound_type,
+               hidden_size=actor_head_hidden_size,
            )
            actor_action_type = DiscreteHead(
                actor_head_hidden_size,
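
With this change the heads no longer require actor_head_hidden_size / critic_head_hidden_size to equal encoder_hidden_size_list[-1]: the encoder output size (encoder_hidden_size_list[-1]) becomes the head input size, while the *_head_hidden_size arguments only set the width inside each head. That is what lets the MuJoCo configs below pair 128-wide encoders with a 256-wide critic head and a 0-layer actor head. A rough sketch of the wiring in plain PyTorch (not the actual DI-engine VAC / RegressionHead / ReparameterizationHead classes; sizes borrowed from the hopper config):

import torch
import torch.nn as nn

obs_shape, action_shape = 11, 3                              # hopper sizes
encoder_hidden_size_list = [128, 128]                        # last element feeds both heads
actor_head_hidden_size, critic_head_hidden_size = 128, 256   # may now differ

def mlp(sizes, activation):
    layers = []
    for i in range(len(sizes) - 1):
        layers += [nn.Linear(sizes[i], sizes[i + 1]), activation]
    return nn.Sequential(*layers)

# share_encoder=False: separate encoders, both ending at encoder_hidden_size_list[-1]
actor_encoder = mlp([obs_shape] + encoder_hidden_size_list, nn.Tanh())
critic_encoder = mlp([obs_shape] + encoder_hidden_size_list, nn.Tanh())

# critic head: input 128 -> two 256-wide hidden layers (critic_head_layer_num=2) -> scalar value
critic_head = nn.Sequential(
    mlp([encoder_hidden_size_list[-1], critic_head_hidden_size, critic_head_hidden_size], nn.Tanh()),
    nn.Linear(critic_head_hidden_size, 1),
)

# actor head: with actor_head_layer_num=0 it reduces to a linear mu layer on the encoder output
actor_mu = nn.Linear(encoder_hidden_size_list[-1], action_shape)

obs = torch.randn(4, obs_shape)
value = critic_head(critic_encoder(obs))       # shape (4, 1)
mu = torch.tanh(actor_mu(actor_encoder(obs)))  # shape (4, 3); bound_type='tanh'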

ding/policy/ppo.py

+25 −1

@@ -52,6 +52,11 @@ class PPOPolicy(Policy):
            batch_size=64,
            # (float) The step size of gradient descent.
            learning_rate=3e-4,
+           # (dict or None) The learning rate decay.
+           # If not None, should contain key 'epoch_num' and 'min_lr_lambda'.
+           # where 'epoch_num' is the total epoch num to decay the learning rate to min value,
+           # 'min_lr_lambda' is the final decayed learning rate.
+           lr_scheduler=None,
            # (float) The loss weight of value network, policy network weight is set to 1.
            value_weight=0.5,
            # (float) The loss weight of entropy regularization, policy network weight is set to 1.
@@ -169,6 +174,16 @@ def _init_learn(self) -> None:
                clip_value=self._cfg.learn.grad_clip_value
            )

+           # Define linear lr scheduler
+           if self._cfg.learn.lr_scheduler is not None:
+               epoch_num = self._cfg.learn.lr_scheduler['epoch_num']
+               min_lr_lambda = self._cfg.learn.lr_scheduler['min_lr_lambda']
+
+               self._lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
+                   self._optimizer,
+                   lr_lambda=lambda epoch: max(1.0 - epoch * (1.0 - min_lr_lambda) / epoch_num, min_lr_lambda)
+               )
+
            self._learn_model = model_wrap(self._model, wrapper_name='base')

            # Algorithm config
@@ -314,8 +329,13 @@ def _forward_learn(self, data: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
                total_loss.backward()
                self._optimizer.step()

+               if self._cfg.learn.lr_scheduler is not None:
+                   cur_lr = sum(self._lr_scheduler.get_last_lr()) / len(self._lr_scheduler.get_last_lr())
+               else:
+                   cur_lr = self._optimizer.defaults['lr']
+
                return_info = {
-                   'cur_lr': self._optimizer.defaults['lr'],
+                   'cur_lr': cur_lr,
                    'total_loss': total_loss.item(),
                    'policy_loss': ppo_loss.policy_loss.item(),
                    'value_loss': ppo_loss.value_loss.item(),
@@ -336,6 +356,10 @@ def _forward_learn(self, data: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
                    }
                )
                return_infos.append(return_info)
+
+       if self._cfg.learn.lr_scheduler is not None:
+           self._lr_scheduler.step()
+
        return return_infos

    def _init_collect(self) -> None:
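
The new lr_scheduler field switches on a linear decay built with torch.optim.lr_scheduler.LambdaLR: the multiplier falls from 1.0 to min_lr_lambda over epoch_num scheduler steps (one step per _forward_learn call) and then stays there. Note that min_lr_lambda is a multiplier on the base learning rate rather than an absolute learning rate, despite the wording of the config comment. A standalone sketch of the same schedule:

import torch

epoch_num, min_lr_lambda = 1500, 0.0          # values used by the updated MuJoCo configs
model = torch.nn.Linear(4, 2)                 # placeholder parameters
optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)
scheduler = torch.optim.lr_scheduler.LambdaLR(
    optimizer,
    lr_lambda=lambda epoch: max(1.0 - epoch * (1.0 - min_lr_lambda) / epoch_num, min_lr_lambda)
)

for step in range(3000):                      # one scheduler step per _forward_learn call
    optimizer.step()                          # (actual training omitted)
    scheduler.step()
    if step in (0, 749, 1499, 2999):
        # lr: ~3e-4 -> 1.5e-4 at half of epoch_num -> 0 at epoch_num -> stays at minimum
        print(step, scheduler.get_last_lr())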

dizoo/mujoco/config/ant_onppo_config.py

+12 −2

@@ -1,4 +1,5 @@
 from easydict import EasyDict
+import torch.nn as nn

 ant_ppo_config = dict(
     exp_name="ant_onppo_seed0",
@@ -17,15 +18,24 @@
         recompute_adv=True,
         action_space='continuous',
         model=dict(
+            encoder_hidden_size_list=[128, 128],
             action_space='continuous',
             obs_shape=111,
             action_shape=8,
+            share_encoder=False,
+            actor_head_layer_num=0,
+            critic_head_layer_num=2,
+            critic_head_hidden_size=256,
+            actor_head_hidden_size=128,
+            activation=nn.Tanh(),
+            bound_type='tanh',
         ),
         learn=dict(
             epoch_per_collect=10,
             update_per_collect=1,
-            batch_size=320,
+            batch_size=128,
             learning_rate=3e-4,
+            lr_scheduler=dict(epoch_num=1500, min_lr_lambda=0),
             value_weight=0.5,
             entropy_weight=0.001,
             clip_ratio=0.2,
@@ -39,7 +49,7 @@
             grad_clip_value=0.5,
         ),
         collect=dict(
-            n_sample=3200,
+            n_sample=2048,
             unroll_len=1,
             discount_factor=0.99,
             gae_lambda=0.95,
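
For reference, a config like this is usually launched through DI-engine's on-policy serial pipeline. A sketch, assuming the companion create-config and the main_config / create_config aliases that these config files conventionally define (neither is shown in this diff):

from dizoo.mujoco.config.ant_onppo_config import main_config, create_config
from ding.entry import serial_pipeline_onpolicy

if __name__ == "__main__":
    # PPO with the VAC model above: separate 128-wide encoders, a 256-wide critic head,
    # and the linear lr decay from 3e-4 toward 0 over 1500 training iterations.
    serial_pipeline_onpolicy((main_config, create_config), seed=0)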

dizoo/mujoco/config/halfcheetah_onppo_config.py

+16 −6

@@ -1,7 +1,8 @@
 from easydict import EasyDict
+import torch.nn as nn

-collector_env_num = 1
-evaluator_env_num = 1
+collector_env_num = 8
+evaluator_env_num = 8
 halfcheetah_ppo_config = dict(
     exp_name='halfcheetah_onppo_seed0',
     env=dict(
@@ -10,23 +11,32 @@
         norm_reward=dict(use_norm=False, ),
         collector_env_num=collector_env_num,
         evaluator_env_num=evaluator_env_num,
-        n_evaluator_episode=1,
+        n_evaluator_episode=8,
         stop_value=12000,
     ),
     policy=dict(
         cuda=True,
         recompute_adv=True,
         action_space='continuous',
         model=dict(
+            encoder_hidden_size_list=[128, 128],
             action_space='continuous',
+            share_encoder=False,
+            actor_head_layer_num=0,
+            critic_head_layer_num=2,
+            critic_head_hidden_size=256,
+            actor_head_hidden_size=128,
             obs_shape=17,
             action_shape=6,
+            activation=nn.Tanh(),
+            bound_type='tanh',
         ),
         learn=dict(
             epoch_per_collect=10,
             update_per_collect=1,
-            batch_size=320,
+            batch_size=128,
             learning_rate=3e-4,
+            lr_scheduler=dict(epoch_num=1500, min_lr_lambda=0),
             value_weight=0.5,
             entropy_weight=0.001,
             clip_ratio=0.2,
@@ -42,12 +52,12 @@
         ),
         collect=dict(
             collector_env_num=collector_env_num,
-            n_sample=3200,
+            n_sample=2048,
             unroll_len=1,
             discount_factor=0.99,
             gae_lambda=0.95,
         ),
-        eval=dict(evaluator=dict(eval_freq=500, )),
+        eval=dict(evaluator=dict(eval_freq=5000, )),
     ),
 )
 halfcheetah_ppo_config = EasyDict(halfcheetah_ppo_config)
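
The new collect/learn numbers fix the per-iteration update count. A quick back-of-envelope check (plain arithmetic, assuming _forward_learn runs once per collect iteration so the lr scheduler also steps once per iteration, as in the ppo.py change above):

n_sample, batch_size, epoch_per_collect = 2048, 128, 10
epoch_num = 1500  # lr_scheduler epoch_num

updates_per_iteration = epoch_per_collect * (n_sample // batch_size)
env_steps_until_min_lr = epoch_num * n_sample

print(updates_per_iteration)    # 160 gradient updates per collected batch
print(env_steps_until_min_lr)   # 3,072,000 environment steps until the lr reaches its minimum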

dizoo/mujoco/config/hopper_onppo_config.py

+14 −4

@@ -1,4 +1,5 @@
 from easydict import EasyDict
+import torch.nn as nn

 hopper_onppo_config = dict(
     exp_name='hopper_onppo_seed0',
@@ -12,19 +13,28 @@
         stop_value=4000,
     ),
     policy=dict(
-        cuda=True,
+        cuda=False,
         recompute_adv=True,
         action_space='continuous',
         model=dict(
+            encoder_hidden_size_list=[128, 128],
             obs_shape=11,
             action_shape=3,
             action_space='continuous',
+            share_encoder=False,
+            actor_head_layer_num=0,
+            critic_head_layer_num=2,
+            critic_head_hidden_size=256,
+            actor_head_hidden_size=128,
+            activation=nn.Tanh(),
+            bound_type='tanh',
         ),
         learn=dict(
             epoch_per_collect=10,
             update_per_collect=1,
-            batch_size=320,
+            batch_size=128,
             learning_rate=3e-4,
+            lr_scheduler=dict(epoch_num=1500, min_lr_lambda=0),
             value_weight=0.5,
             entropy_weight=0.001,
             clip_ratio=0.2,
@@ -39,12 +49,12 @@
             grad_clip_value=0.5,
         ),
         collect=dict(
-            n_sample=3200,
+            n_sample=2048,
             unroll_len=1,
             discount_factor=0.99,
             gae_lambda=0.95,
         ),
-        eval=dict(evaluator=dict(eval_freq=500, )),
+        eval=dict(evaluator=dict(eval_freq=5000, )),
     ),
 )
 hopper_onppo_config = EasyDict(hopper_onppo_config)

dizoo/mujoco/config/walker2d_onppo_config.py

+16 −6

@@ -1,7 +1,8 @@
 from easydict import EasyDict
+import torch.nn as nn

-collector_env_num = 1
-evaluator_env_num = 1
+collector_env_num = 8
+evaluator_env_num = 8
 walker2d_onppo_config = dict(
     exp_name='walker2d_onppo_seed0',
     env=dict(
@@ -10,23 +11,32 @@
         norm_reward=dict(use_norm=False, ),
         collector_env_num=collector_env_num,
         evaluator_env_num=evaluator_env_num,
-        n_evaluator_episode=10,
+        n_evaluator_episode=8,
         stop_value=6000,
     ),
     policy=dict(
         cuda=True,
         recompute_adv=True,
         action_space='continuous',
         model=dict(
+            encoder_hidden_size_list=[128, 128],
             action_space='continuous',
+            share_encoder=False,
+            actor_head_layer_num=0,
+            critic_head_layer_num=2,
+            critic_head_hidden_size=256,
+            actor_head_hidden_size=128,
             obs_shape=17,
             action_shape=6,
+            activation=nn.Tanh(),
+            bound_type='tanh',
         ),
         learn=dict(
             epoch_per_collect=10,
             update_per_collect=1,
-            batch_size=320,
+            batch_size=128,
             learning_rate=3e-4,
+            lr_scheduler=dict(epoch_num=1500, min_lr_lambda=0),
             value_weight=0.5,
             entropy_weight=0.001,
             clip_ratio=0.2,
@@ -43,12 +53,12 @@
         ),
         collect=dict(
             collector_env_num=collector_env_num,
-            n_sample=3200,
+            n_sample=2048,
             unroll_len=1,
             discount_factor=0.99,
             gae_lambda=0.95,
         ),
-        eval=dict(evaluator=dict(eval_freq=500, )),
+        eval=dict(evaluator=dict(eval_freq=5000, )),
     ),
 )
 walker2d_onppo_config = EasyDict(walker2d_onppo_config)
