Skip to content

Commit 13a6d45

Browse files
committed
fix(nyz): fix gtrxl compatibility bug (#796)
1 parent b2aab8d commit 13a6d45

File tree

3 files changed

+9
-11
lines changed

3 files changed

+9
-11
lines changed

ding/model/template/q_learning.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -1118,8 +1118,9 @@ def __init__(
11181118
gru_bias=gru_bias,
11191119
)
11201120

1121+
# for vector obs, use an identity encoder, i.e. leave the Transformer's embedding layer as is
11211122
if isinstance(obs_shape, int) or len(obs_shape) == 1:
1122-
raise NotImplementedError("not support obs_shape for pre-defined encoder: {}".format(obs_shape))
1123+
pass
11231124
# replace the embedding layer of Transformer with Conv Encoder
11241125
elif len(obs_shape) == 3:
11251126
assert encoder_hidden_size_list[-1] == hidden_size

ding/policy/r2d2_gtrxl.py

+4-7
Original file line numberDiff line numberDiff line change
@@ -55,11 +55,9 @@ class R2D2GTrXLPolicy(Policy):
5555
| ``done`` | calculation. | fake termination env
5656
15 ``collect.n_sample`` int [8, 128] | The number of training samples of a | It varies from
5757
| call of collector. | different envs
58-
16 | ``collect.unroll`` int 25 | unroll length of an iteration | unroll_len>1
58+
16 | ``collect.seq`` int 20 | Training sequence length | unroll_len>=seq_len>1
5959
| ``_len``
60-
17 | ``collect.seq`` int 20 | Training sequence length | unroll_len>=seq_len>1
61-
| ``_len``
62-
18 | ``learn.init_`` str zero | 'zero' or 'old', how to initialize the |
60+
17 | ``learn.init_`` str zero | 'zero' or 'old', how to initialize the |
6361
| ``memory`` | memory before each training iteration. |
6462
== ==================== ======== ============== ======================================== =======================
6563
"""
@@ -81,7 +79,7 @@ class R2D2GTrXLPolicy(Policy):
8179
discount_factor=0.99,
8280
# (int) N-step reward for target q_value estimation
8381
nstep=5,
84-
# how many steps to use as burnin
82+
# (int) Number of steps used in the burn-in phase
8583
burnin_step=1,
8684
# (int) trajectory length
8785
unroll_len=25,
@@ -158,7 +156,7 @@ def _init_learn(self) -> None:
158156
self._seq_len = self._cfg.seq_len
159157
self._value_rescale = self._cfg.learn.value_rescale
160158
self._init_memory = self._cfg.learn.init_memory
161-
assert self._init_memory in ['zero', 'old']
159+
assert self._init_memory in ['zero', 'old'], self._init_memory
162160

163161
self._target_model = copy.deepcopy(self._model)
164162

@@ -352,7 +350,6 @@ def _init_collect(self) -> None:
352350
Collect mode init method. Called by ``self.__init__``.
353351
Initialize the unroll length, the sequence length, and the collect model.
354352
"""
355-
assert 'unroll_len' not in self._cfg.collect, "Use default unroll_len"
356353
self._nstep = self._cfg.nstep
357354
self._gamma = self._cfg.discount_factor
358355
self._unroll_len = self._cfg.unroll_len

dizoo/classic_control/cartpole/config/cartpole_r2d2_gtrxl_config.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -18,11 +18,11 @@
1818
obs_shape=4,
1919
action_shape=2,
2020
memory_len=5, # length of transformer memory (can be 0)
21-
hidden_size=256,
21+
hidden_size=64,
2222
gru_bias=2.,
2323
att_layer_num=3,
2424
dropout=0.,
25-
att_head_num=8,
25+
att_head_num=4,
2626
),
2727
discount_factor=0.99,
2828
nstep=3,
@@ -31,7 +31,7 @@
3131
seq_len=8, # transformer input segment
3232
# training sequence: unroll_len - burnin_step - nstep
3333
learn=dict(
34-
update_per_collect=8,
34+
update_per_collect=16,
3535
batch_size=64,
3636
learning_rate=0.0005,
3737
target_update_freq=500,

0 commit comments

Comments
 (0)