
Commit 33554e7

polish config
1 parent 191fe53 commit 33554e7

File tree

2 files changed: +30 -39 lines changed


ding/policy/qtransformer.py

+2 -17
@@ -216,21 +216,6 @@ def _init_learn(self) -> None:
         )
 
         self._with_q_entropy = self._cfg.learn.with_q_entropy
-
-        # # Weight Init
-        # init_w = self._cfg.learn.init_w
-        # self._model.actor_head[-1].mu.weight.data.uniform_(-init_w, init_w)
-        # self._model.actor_head[-1].mu.bias.data.uniform_(-init_w, init_w)
-        # self._model.actor_head[-1].log_sigma_layer.weight.data.uniform_(-init_w, init_w)
-        # self._model.actor_head[-1].log_sigma_layer.bias.data.uniform_(-init_w, init_w)
-        # if self._twin_critic:
-        # self._model.critic_head[0][-1].last.weight.data.uniform_(-init_w, init_w)
-        # self._model.critic_head[0][-1].last.bias.data.uniform_(-init_w, init_w)
-        # self._model.critic_head[1][-1].last.weight.data.uniform_(-init_w, init_w)
-        # self._model.critic_head[1][-1].last.bias.data.uniform_(-init_w, init_w)
-        # else:
-        # self._model.critic_head[2].last.weight.data.uniform_(-init_w, init_w)
-        # self._model.critic_head[-1].last.bias.data.uniform_(-init_w, init_w)
         # Optimizers
         self._optimizer_q = Adam(
             self._model.parameters(),
@@ -288,8 +273,8 @@ def _init_learn(self) -> None:
             update_type="momentum",
             update_kwargs={"theta": self._cfg.learn.target_theta},
         )
-        self._low = np.array(self._cfg.other["low"])
-        self._high = np.array(self._cfg.other["high"])
+        self._low = np.array([-1, -1, -1])
+        self._high = np.array([1, 1, 1])
         self._action_bin = self._cfg.model.action_bins
         self._action_values = np.array(
             [
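The second hunk above hard-codes the action bounds that were previously read from cfg.other. As a hedged illustration only (not the actual body of _init_learn, which is truncated here), the per-dimension bin values that _action_values is built from could be derived from these bounds and action_bins roughly as follows; the np.linspace layout and the local variable names are assumptions for this sketch.

import numpy as np

# Assumed values mirroring the diff: 3 action dimensions bounded to [-1, 1],
# each discretized into 16 bins (cfg.model.action_bins).
low = np.array([-1, -1, -1])
high = np.array([1, 1, 1])
action_bins = 16

# One row of evenly spaced bin values per action dimension (hypothetical layout).
action_values = np.array(
    [np.linspace(lo, hi, action_bins) for lo, hi in zip(low, high)]
)
print(action_values.shape)  # (3, 16)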

dizoo/d4rl/config/hopper_medium_expert_qtransformer_config.py

+28 -22
@@ -5,43 +5,47 @@
 main_config = dict(
     exp_name="hopper_medium_expert_qtransformer_seed0",
     env=dict(
-        env_id='hopper-medium-expert-v0',
+        env_id="hopper-medium-expert-v0",
         collector_env_num=5,
         evaluator_env_num=8,
         use_act_scale=True,
         n_evaluator_episode=8,
         stop_value=6000,
     ),
-
     policy=dict(
         cuda=True,
-
         model=dict(
-            num_actions = 3,
-            action_bins = 16,
-            obs_dim = 11,
-            dueling = False,
-            attend_dim = 512,
+            num_actions=3,
+            action_bins=16,
+            obs_dim=11,
+            dueling=False,
+            attend_dim=512,
         ),
-
         learn=dict(
             data_path=None,
             train_epoch=3000,
             batch_size=2048,
             learning_rate_q=3e-4,
             alpha=0.2,
             discount_factor_gamma=0.99,
-            min_reward = 0.0,
+            min_reward=0.0,
             auto_alpha=False,
             lagrange_thresh=-1.0,
             min_q_weight=5.0,
         ),
-        collect=dict(data_type='d4rl', ),
-        eval=dict(evaluator=dict(eval_freq=5, )),
-        other=dict(replay_buffer=dict(replay_buffer_size=2000000, ),
-            low = [-1, -1, -1],
-            high = [1, 1, 1],
-        ),
+        collect=dict(
+            data_type="d4rl",
+        ),
+        eval=dict(
+            evaluator=dict(
+                eval_freq=5,
+            )
+        ),
+        other=dict(
+            replay_buffer=dict(
+                replay_buffer_size=2000000,
+            ),
+        ),
     ),
 )
 
@@ -50,15 +54,17 @@
 
 create_config = dict(
     env=dict(
-        type='d4rl',
-        import_names=['dizoo.d4rl.envs.d4rl_env'],
+        type="d4rl",
+        import_names=["dizoo.d4rl.envs.d4rl_env"],
     ),
-    env_manager=dict(type='base'),
+    env_manager=dict(type="base"),
     policy=dict(
-        type='qtransformer',
-        import_names=['ding.policy.qtransformer'],
+        type="qtransformer",
+        import_names=["ding.policy.qtransformer"],
+    ),
+    replay_buffer=dict(
+        type="naive",
     ),
-    replay_buffer=dict(type='naive', ),
 )
 create_config = EasyDict(create_config)
 create_config = create_config
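For completeness, a hedged usage sketch of launching this config: DI-engine offline configs of this kind are normally handed to an entry pipeline; serial_pipeline_offline and the seed below are assumptions for illustration, not part of this commit.

# Hypothetical runner script; serial_pipeline_offline is assumed to be the
# intended DI-engine entry point for this offline d4rl config.
from ding.entry import serial_pipeline_offline
from dizoo.d4rl.config.hopper_medium_expert_qtransformer_config import (
    create_config,
    main_config,
)

if __name__ == "__main__":
    serial_pipeline_offline([main_config, create_config], seed=0)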
