
Commit b06ce44

fix(nyz): fix mlp dropout if condition bug
1 parent 0968250 commit b06ce44

File tree

ding/policy/dqn.py
ding/policy/mbpolicy/mbsac.py
ding/torch_utils/network/nn_module.py

3 files changed: +8 -9 lines

ding/policy/dqn.py  (+1 -1)

@@ -41,7 +41,7 @@ class DQNPolicy(Policy):
 8  | ``model.dueling``  bool   True      | dueling head architecture
 9  | ``model.encoder``  list   [32, 64,  | Sequence of ``hidden_size`` of   | default kernel_size
    | ``_hidden``        (int)  64, 128]  | subsequent conv layers and the   | is [8, 4, 3]
-   | ``_size_list``                      | final dense layer.               | default stride is
+   | ``_size_list``                      | final dense layer.               | default stride is
                                                                             | [4, 2 ,1]
 10 | ``model.dropout``  float  None      | Dropout rate for dropout layers. | [0,1]
                                                                             | If set to ``None``
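
The ``model.dropout`` row documents the user-facing dropout setting for the DQN model; presumably it is what ends up driving the ``use_dropout`` / ``dropout_probability`` arguments touched in the ``nn_module.py`` hunk below, although that wiring is not part of this commit. A minimal sketch of enabling it in a DI-engine-style EasyDict config (the ``obs_shape`` and ``action_shape`` fields are placeholders, not taken from this commit):

    # Hedged sketch: only ``dueling`` and ``dropout`` correspond to rows 8 and 10
    # of the table above; the other fields are illustrative placeholders.
    from easydict import EasyDict

    dqn_model_cfg = EasyDict(dict(
        obs_shape=4,     # placeholder observation size
        action_shape=2,  # placeholder action count
        dueling=True,    # dueling head architecture (row 8)
        dropout=0.1,     # dropout rate in [0, 1]; defaults to None in the table
    ))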

ding/policy/mbpolicy/mbsac.py  (+5 -6)

@@ -1,15 +1,13 @@
 from typing import Dict, Any, List
 from functools import partial
-import copy

 import torch
 from torch import Tensor
 from torch import nn
-from torch.distributions import Normal, Independent, TransformedDistribution, TanhTransform
-from easydict import EasyDict
+from torch.distributions import Normal, Independent

 from ding.torch_utils import to_device, fold_batch, unfold_batch, unsqueeze_repeat
-from ding.utils import POLICY_REGISTRY, deep_merge_dicts
+from ding.utils import POLICY_REGISTRY
 from ding.policy import SACPolicy
 from ding.rl_utils import generalized_lambda_returns
 from ding.policy.common_utils import default_preprocess_learn
@@ -33,11 +31,12 @@ class MBSACPolicy(SACPolicy):
 == ==================== ======== ============= ==================================
 1  ``learn._lambda``    float    0.8           | Lambda for TD-lambda return.
 2  ``learn.grad_clip`   float    100.0         | Max norm of gradients.
-3  ``learn.sample_``    bool     True          | Whether to sample states or tra-
-   ``state``                                   | nsitions from env buffer.
+3  | ``learn.sample``   bool     True          | Whether to sample states or
+   | ``_state``                                | transitions from env buffer.
 == ==================== ======== ============= ==================================

 .. note::
+
     For other configs, please refer to ding.policy.sac.SACPolicy.
 """

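The renamed ``learn.sample_state`` row documents whether states or whole transitions are sampled from the env buffer. A minimal sketch of overriding the three documented learn options in an EasyDict config (the flat nesting here is assumed from the table, not shown in this commit):

    # Hedged sketch: only the keys documented in the table above are used.
    from easydict import EasyDict

    mbsac_learn_cfg = EasyDict(dict(
        _lambda=0.8,        # lambda for the TD-lambda return
        grad_clip=100.0,    # max norm of gradients
        sample_state=True,  # True: sample states from the env buffer; False: sample transitions
    ))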

ding/torch_utils/network/nn_module.py  (+2 -2)

@@ -376,7 +376,7 @@ def MLP(
             block.append(build_normalization(norm_type, dim=1)(out_channels))
         if activation is not None:
             block.append(activation)
-        if use_dropout is not None:
+        if use_dropout:
             block.append(nn.Dropout(dropout_probability))

     # The last layer
@@ -396,7 +396,7 @@ def MLP(
         # The last layer uses the same activation as front layers.
         if activation is not None:
             block.append(activation)
-        if use_dropout is not None:
+        if use_dropout:
             block.append(nn.Dropout(dropout_probability))

     if last_linear_layer_init_zero:
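
This last pair of hunks is the bug fix named in the commit title. ``use_dropout`` is a boolean flag, so the old check ``use_dropout is not None`` evaluated to true even when the flag was ``False``, and an ``nn.Dropout`` layer was appended regardless; testing truthiness adds the layer only when dropout is actually enabled. A simplified, self-contained sketch of the corrected behaviour (a stand-in for the relevant part of ``MLP``, not the actual implementation):

    # Hedged sketch: build one Linear + activation block, adding Dropout only
    # when the flag is truthy, mirroring the corrected condition above.
    from torch import nn


    def build_block(in_channels: int, out_channels: int,
                    use_dropout: bool = False, dropout_probability: float = 0.5) -> nn.Sequential:
        block = [nn.Linear(in_channels, out_channels), nn.ReLU()]
        if use_dropout:  # fixed condition: a False flag no longer adds dropout
            block.append(nn.Dropout(dropout_probability))
        return nn.Sequential(*block)


    # With the old ``is not None`` check, the first block would also have
    # contained a Dropout module even though use_dropout is False.
    assert not any(isinstance(m, nn.Dropout) for m in build_block(8, 16, use_dropout=False))
    assert any(isinstance(m, nn.Dropout) for m in build_block(8, 16, use_dropout=True, dropout_probability=0.1))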
