
Commit 0968250

feature(lxy): add dropout layers to dqn (#712)
* add dropout layers to dqn encoder and head, also add config to control dropout
* polish style
* add dropout rate range in comment
* add workflow time limit
* polish config
1 parent 3059479 commit 0968250


7 files changed (+66, -23 lines changed)

ding/model/common/encoder.py (+9, -3)

```diff
@@ -141,7 +141,8 @@ def __init__(
             hidden_size_list: SequenceType,
             res_block: bool = False,
             activation: Optional[nn.Module] = nn.ReLU(),
-            norm_type: Optional[str] = None
+            norm_type: Optional[str] = None,
+            dropout: Optional[float] = None
     ) -> None:
         """
         Overview:
@@ -153,6 +154,7 @@ def __init__(
             - activation (:obj:`nn.Module`): Type of activation to use in ``ResFCBlock``. Default is ``nn.ReLU()``.
             - norm_type (:obj:`str`): Type of normalization to use. See ``ding.torch_utils.network.ResFCBlock`` \
                 for more details. Default is ``None``.
+            - dropout (:obj:`float`): Dropout rate of the dropout layer. If ``None`` then default no dropout layer.
         """
         super(FCEncoder, self).__init__()
         self.obs_shape = obs_shape
@@ -162,17 +164,21 @@ def __init__(
         if res_block:
             assert len(set(hidden_size_list)) == 1, "Please indicate the same hidden size for res block parts"
             if len(hidden_size_list) == 1:
-                self.main = ResFCBlock(hidden_size_list[0], activation=self.act, norm_type=norm_type)
+                self.main = ResFCBlock(hidden_size_list[0], activation=self.act, norm_type=norm_type, dropout=dropout)
             else:
                 layers = []
                 for i in range(len(hidden_size_list)):
-                    layers.append(ResFCBlock(hidden_size_list[0], activation=self.act, norm_type=norm_type))
+                    layers.append(
+                        ResFCBlock(hidden_size_list[0], activation=self.act, norm_type=norm_type, dropout=dropout)
+                    )
                 self.main = nn.Sequential(*layers)
         else:
             layers = []
             for i in range(len(hidden_size_list) - 1):
                 layers.append(nn.Linear(hidden_size_list[i], hidden_size_list[i + 1]))
                 layers.append(self.act)
+                if dropout is not None:
+                    layers.append(nn.Dropout(dropout))
             self.main = nn.Sequential(*layers)
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
```
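To see what the encoder change does in practice, here is a hedged usage sketch (the sizes are illustrative, and the leading ``obs_shape`` argument is assumed from the DQN call shown further below):

```python
import torch
from ding.model.common.encoder import FCEncoder

# Illustrative sizes; with dropout=0.1 the plain (non-res-block) branch appends
# nn.Dropout(0.1) after every Linear + activation pair.
encoder = FCEncoder(4, hidden_size_list=[128, 64], dropout=0.1)
print(encoder.main)          # Sequential containing Dropout(p=0.1) layers

obs = torch.randn(8, 4)      # a batch of 8 flat observations
print(encoder(obs).shape)    # expected: torch.Size([8, 64])
```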

ding/model/common/head.py (+10)

```diff
@@ -28,6 +28,7 @@ def __init__(
             layer_num: int = 1,
             activation: Optional[nn.Module] = nn.ReLU(),
             norm_type: Optional[str] = None,
+            dropout: Optional[float] = None,
             noise: Optional[bool] = False,
     ) -> None:
         """
@@ -41,6 +42,7 @@ def __init__(
                 If ``None``, then default set activation to ``nn.ReLU()``. Default ``None``.
             - norm_type (:obj:`str`): The type of normalization to use. See ``ding.torch_utils.network.fc_block`` \
                 for more details. Default ``None``.
+            - dropout (:obj:`float`): The dropout rate, default set to None.
             - noise (:obj:`bool`): Whether use ``NoiseLinearLayer`` as ``layer_fn`` in Q networks' MLP. \
                 Default ``False``.
         """
@@ -55,6 +57,8 @@ def __init__(
                 layer_num,
                 layer_fn=layer,
                 activation=activation,
+                use_dropout=dropout is not None,
+                dropout_probability=dropout,
                 norm_type=norm_type
             ), block(hidden_size, output_size)
         )
@@ -800,6 +804,7 @@ def __init__(
             v_layer_num: Optional[int] = None,
             activation: Optional[nn.Module] = nn.ReLU(),
             norm_type: Optional[str] = None,
+            dropout: Optional[float] = None,
             noise: Optional[bool] = False,
     ) -> None:
         """
@@ -814,6 +819,7 @@ def __init__(
                 If ``None``, then default set activation to ``nn.ReLU()``. Default ``None``.
             - norm_type (:obj:`str`): The type of normalization to use. See ``ding.torch_utils.network.fc_block`` \
                 for more details. Default ``None``.
+            - dropout (:obj:`float`): The dropout rate of dropout layer. Default ``None``.
             - noise (:obj:`bool`): Whether use ``NoiseLinearLayer`` as ``layer_fn`` in Q networks' MLP. \
                 Default ``False``.
         """
@@ -832,6 +838,8 @@ def __init__(
                 a_layer_num,
                 layer_fn=layer,
                 activation=activation,
+                use_dropout=dropout is not None,
+                dropout_probability=dropout,
                 norm_type=norm_type
             ), block(hidden_size, output_size)
         )
@@ -843,6 +851,8 @@ def __init__(
                 v_layer_num,
                 layer_fn=layer,
                 activation=activation,
+                use_dropout=dropout is not None,
+                dropout_probability=dropout,
                 norm_type=norm_type
             ), block(hidden_size, 1)
         )
```
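The head changes only thread the new rate into ``MLP`` via ``use_dropout`` / ``dropout_probability``. A hedged construction sketch follows; the leading ``hidden_size``/``output_size`` positional arguments and the ``'logit'`` output key are assumed from surrounding DI-engine code, not shown in this diff:

```python
import torch
from ding.model.common.head import DiscreteHead, DuelingHead

# Illustrative sizes: 64-dim features, 2 discrete actions.
head = DiscreteHead(64, 2, layer_num=2, dropout=0.1)   # MLP receives use_dropout=True, dropout_probability=0.1
dueling = DuelingHead(64, 2, dropout=0.1)              # both the A- and V-branch MLPs get dropout

feat = torch.randn(8, 64)
print(head(feat)['logit'].shape)      # expected: torch.Size([8, 2])
print(dueling(feat)['logit'].shape)   # expected: torch.Size([8, 2])
```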

ding/model/template/q_learning.py (+17, -5)

```diff
@@ -21,7 +21,8 @@ def __init__(
             head_hidden_size: Optional[int] = None,
             head_layer_num: int = 1,
             activation: Optional[nn.Module] = nn.ReLU(),
-            norm_type: Optional[str] = None
+            norm_type: Optional[str] = None,
+            dropout: Optional[float] = None
     ) -> None:
         """
         Overview:
@@ -35,9 +36,11 @@ def __init__(
             - head_hidden_size (:obj:`Optional[int]`): The ``hidden_size`` of head network.
             - head_layer_num (:obj:`int`): The number of layers used in the head network to compute Q value output
             - activation (:obj:`Optional[nn.Module]`): The type of activation function in networks \
-                if ``None`` then default set it to ``nn.ReLU()``
+                if ``None`` then default set it to ``nn.ReLU()``.
             - norm_type (:obj:`Optional[str]`): The type of normalization in networks, see \
                 ``ding.torch_utils.fc_block`` for more details. you can choose one of ['BN', 'IN', 'SyncBN', 'LN']
+            - dropout (:obj:`Optional[float]`): The dropout rate of the dropout layer. \
+                if ``None`` then default no dropout layer.
         """
         super(DQN, self).__init__()
         # Squeeze data from tuple, list or dict to single object. For example, from (4, ) to 4
@@ -46,9 +49,12 @@ def __init__(
             head_hidden_size = encoder_hidden_size_list[-1]
         # FC Encoder
         if isinstance(obs_shape, int) or len(obs_shape) == 1:
-            self.encoder = FCEncoder(obs_shape, encoder_hidden_size_list, activation=activation, norm_type=norm_type)
+            self.encoder = FCEncoder(
+                obs_shape, encoder_hidden_size_list, activation=activation, norm_type=norm_type, dropout=dropout
+            )
         # Conv Encoder
         elif len(obs_shape) == 3:
+            assert dropout is None, "dropout is not supported in ConvEncoder"
             self.encoder = ConvEncoder(obs_shape, encoder_hidden_size_list, activation=activation, norm_type=norm_type)
         else:
             raise RuntimeError(
@@ -67,11 +73,17 @@ def __init__(
                 action_shape,
                 layer_num=head_layer_num,
                 activation=activation,
-                norm_type=norm_type
+                norm_type=norm_type,
+                dropout=dropout
             )
         else:
             self.head = head_cls(
-                head_hidden_size, action_shape, head_layer_num, activation=activation, norm_type=norm_type
+                head_hidden_size,
+                action_shape,
+                head_layer_num,
+                activation=activation,
+                norm_type=norm_type,
+                dropout=dropout
             )
 
     def forward(self, x: torch.Tensor) -> Dict:
```
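At the model level, a single ``dropout`` argument now switches dropout on for both encoder and head. A hedged sketch, with argument values taken from the cartpole config at the end of this commit (only vector observations accept it, since the 3-dimensional ``ConvEncoder`` path asserts ``dropout is None``):

```python
import torch
from ding.model.template.q_learning import DQN

model = DQN(obs_shape=4, action_shape=2, encoder_hidden_size_list=[128, 128, 64], dropout=0.5)

obs = torch.randn(16, 4)
print(model(obs)['logit'].shape)   # expected: torch.Size([16, 2])

# Image observations use ConvEncoder, which does not support dropout:
# DQN(obs_shape=(4, 84, 84), action_shape=6, dropout=0.1)  # -> AssertionError
```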

ding/policy/dqn.py (+16, -13)

```diff
@@ -43,34 +43,37 @@ class DQNPolicy(Policy):
        | ``_hidden``          (int)   64, 128]       | subsequent conv layers and the         | is [8, 4, 3]
        | ``_size_list``                              | final dense layer.                     | default stride is
        |                                             |                                        | [4, 2 ,1]
-    10 | ``learn.update``     int     3              | How many updates(iterations) to train  | This args can be vary
+    10 | ``model.dropout``    float   None           | Dropout rate for dropout layers.       | [0,1]
+       |                                             | If set to ``None``
+       |                                             | means no dropout
+    11 | ``learn.update``     int     3              | How many updates(iterations) to train  | This args can be vary
        | ``per_collect``                             | after collector's one collection.      | from envs. Bigger val
        |                                             | Only valid in serial training          | means more off-policy
-    11 | ``learn.batch_``     int     64             | The number of samples of an iteration
+    12 | ``learn.batch_``     int     64             | The number of samples of an iteration
        | ``size``
-    12 | ``learn.learning``   float   0.001          | Gradient step length of an iteration.
+    13 | ``learn.learning``   float   0.001          | Gradient step length of an iteration.
        | ``_rate``
-    13 | ``learn.target_``    int     100            | Frequence of target network update.    | Hard(assign) update
+    14 | ``learn.target_``    int     100            | Frequence of target network update.    | Hard(assign) update
        | ``update_freq``
-    14 | ``learn.target_``    float   0.005          | Frequence of target network update.    | Soft(assign) update
+    15 | ``learn.target_``    float   0.005          | Frequence of target network update.    | Soft(assign) update
        | ``theta``                                   | Only one of [target_update_freq,
        |                                             | target_theta] should be set
-    15 | ``learn.ignore_``    bool    False          | Whether ignore done for target value   | Enable it for some
+    16 | ``learn.ignore_``    bool    False          | Whether ignore done for target value   | Enable it for some
        | ``done``                                    | calculation.                           | fake termination env
-    16 ``collect.n_sample``   int     [8, 128]       | The number of training samples of a    | It varies from
+    17 ``collect.n_sample``   int     [8, 128]       | The number of training samples of a    | It varies from
        |                                             | call of collector.                     | different envs
-    17 ``collect.n_episode``  int     8              | The number of training episodes of a   | only one of [n_sample
+    18 ``collect.n_episode``  int     8              | The number of training episodes of a   | only one of [n_sample
        |                                             | call of collector                      | ,n_episode] should
        |                                             |                                        | be set
-    18 | ``collect.unroll``   int     1              | unroll length of an iteration          | In RNN, unroll_len>1
+    19 | ``collect.unroll``   int     1              | unroll length of an iteration          | In RNN, unroll_len>1
        | ``_len``
-    19 | ``other.eps.type``   str     exp            | exploration rate decay type            | Support ['exp',
+    20 | ``other.eps.type``   str     exp            | exploration rate decay type            | Support ['exp',
        |                                             | 'linear'].
-    20 | ``other.eps.``       float   0.95           | start value of exploration rate        | [0,1]
+    21 | ``other.eps.``       float   0.95           | start value of exploration rate        | [0,1]
        | ``start``
-    21 | ``other.eps.``       float   0.1            | end value of exploration rate          | [0,1]
+    22 | ``other.eps.``       float   0.1            | end value of exploration rate          | [0,1]
        | ``end``
-    22 | ``other.eps.``       int     10000          | decay length of exploration            | greater than 0. set
+    23 | ``other.eps.``       int     10000          | decay length of exploration            | greater than 0. set
        | ``decay``                                   | decay=10000 means
        |                                             | the exploration rate
        |                                             | decay from start
```
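The new table row corresponds to the ``model.dropout`` key of the policy config. A hedged sketch of the relevant config subtree (structure mirrors the cartpole config below; a real config carries many more fields):

```python
from easydict import EasyDict

cfg = EasyDict(dict(
    model=dict(
        obs_shape=4,
        action_shape=2,
        encoder_hidden_size_list=[128, 128, 64],
        dueling=True,
        dropout=0.5,   # float in [0, 1]; None (the default) means no dropout layers
    ),
))
print(cfg.model.dropout)
```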

ding/torch_utils/network/nn_module.py (+3, -1)

```diff
@@ -376,7 +376,7 @@ def MLP(
             block.append(build_normalization(norm_type, dim=1)(out_channels))
         if activation is not None:
             block.append(activation)
-        if use_dropout:
+        if use_dropout is not None:
             block.append(nn.Dropout(dropout_probability))
 
     # The last layer
@@ -396,6 +396,8 @@ def MLP(
     # The last layer uses the same activation as front layers.
     if activation is not None:
         block.append(activation)
+    if use_dropout is not None:
+        block.append(nn.Dropout(dropout_probability))
 
     if last_linear_layer_init_zero:
         # Locate the last linear layer and initialize its weights and biases to 0.
```
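For orientation, a self-contained toy re-implementation of the block-building pattern these hunks touch (this is not the library's ``MLP``, just a sketch of where the optional ``nn.Dropout`` is appended, now including the last layer):

```python
import torch.nn as nn


def tiny_mlp(in_ch: int, hidden_ch: int, out_ch: int, layer_num: int, dropout_probability: float = None):
    """Toy sketch: linear -> activation -> optional dropout, repeated for every layer."""
    channels = [in_ch] + [hidden_ch] * (layer_num - 1) + [out_ch]
    block = []
    for i in range(len(channels) - 1):
        block.append(nn.Linear(channels[i], channels[i + 1]))
        block.append(nn.ReLU())
        if dropout_probability is not None:
            block.append(nn.Dropout(dropout_probability))
    return nn.Sequential(*block)


print(tiny_mlp(64, 64, 2, layer_num=2, dropout_probability=0.1))
```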

ding/torch_utils/network/res_block.py (+10, -1)

```diff
@@ -111,17 +111,24 @@ class ResFCBlock(nn.Module):
         forward
     """
 
-    def __init__(self, in_channels: int, activation: nn.Module = nn.ReLU(), norm_type: str = 'BN'):
+    def __init__(
+        self, in_channels: int, activation: nn.Module = nn.ReLU(), norm_type: str = 'BN', dropout: float = None
+    ):
         r"""
         Overview:
             Init the fully connected layer residual block.
         Arguments:
             - in_channels (:obj:`int`): The number of channels in the input tensor.
             - activation (:obj:`nn.Module`): The optional activation function.
             - norm_type (:obj:`str`): The type of the normalization, default set to 'BN'.
+            - dropout (:obj:`float`): The dropout rate, default set to None.
         """
         super(ResFCBlock, self).__init__()
         self.act = activation
+        if dropout is not None:
+            self.dropout = nn.Dropout(dropout)
+        else:
+            self.dropout = None
         self.fc1 = fc_block(in_channels, in_channels, activation=self.act, norm_type=norm_type)
         self.fc2 = fc_block(in_channels, in_channels, activation=None, norm_type=norm_type)
 
@@ -138,4 +145,6 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         x = self.fc1(x)
         x = self.fc2(x)
         x = self.act(x + identity)
+        if self.dropout is not None:
+            x = self.dropout(x)
         return x
```
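A hedged usage sketch of the updated residual block; the rate is applied once, after the residual addition and activation (default ``norm_type='BN'``, so run it on a batch in train mode):

```python
import torch
from ding.torch_utils.network.res_block import ResFCBlock

block = ResFCBlock(64, dropout=0.1)
x = torch.randn(8, 64)
print(block(x).shape)   # expected: torch.Size([8, 64])
```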

dizoo/classic_control/cartpole/config/cartpole_dqn_config.py (+1)

```diff
@@ -17,6 +17,7 @@
             action_shape=2,
             encoder_hidden_size_list=[128, 128, 64],
             dueling=True,
+            dropout=0.5,
         ),
         nstep=1,
         discount_factor=0.97,
```
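One practical note when enabling this config option: ``nn.Dropout`` only drops activations in training mode and becomes the identity in eval mode, so the usual ``model.train()`` / ``model.eval()`` discipline matters for collecting versus evaluating. A generic PyTorch reminder (not repo-specific):

```python
import torch
import torch.nn as nn

layer = nn.Dropout(p=0.5)
x = torch.ones(4)

layer.train()
print(layer(x))   # roughly half the entries zeroed, survivors scaled by 1 / (1 - 0.5) = 2.0

layer.eval()
print(layer(x))   # identity: tensor([1., 1., 1., 1.])
```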
