@@ -54,12 +54,12 @@ def __init__(
             ``ReparameterizationHead``, and hybrid heads.
         - share_encoder (:obj:`bool`): Whether to share observation encoders between actor and decoder.
         - encoder_hidden_size_list (:obj:`SequenceType`): Collection of ``hidden_size`` to pass to ``Encoder``, \
-            the last element must match ``head_hidden_size``.
+            the last element is used as the input size of ``actor_head`` and ``critic_head``.
         - actor_head_hidden_size (:obj:`Optional[int]`): The ``hidden_size`` of ``actor_head`` network, defaults \
-            to 64, it must match the last element of ``encoder_hidden_size_list``.
+            to 64, it is the hidden size of the last layer of the ``actor_head`` network.
         - actor_head_layer_num (:obj:`int`): The num of layers used in the ``actor_head`` network to compute action.
         - critic_head_hidden_size (:obj:`Optional[int]`): The ``hidden_size`` of ``critic_head`` network, defaults \
-            to 64, it must match the last element of ``encoder_hidden_size_list``.
+            to 64, it is the hidden size of the last layer of the ``critic_head`` network.
         - critic_head_layer_num (:obj:`int`): The num of layers used in the ``critic_head`` network.
         - activation (:obj:`Optional[nn.Module]`): The type of activation function in networks \
             if ``None`` then default set it to ``nn.ReLU()``.
@@ -108,15 +108,13 @@ def new_encoder(outsize, activation):
             )
 
         if self.share_encoder:
-            assert actor_head_hidden_size == critic_head_hidden_size, \
-                "actor and critic network head should have same size."
             if encoder:
                 if isinstance(encoder, torch.nn.Module):
                     self.encoder = encoder
                 else:
                     raise ValueError("illegal encoder instance.")
             else:
-                self.encoder = new_encoder(actor_head_hidden_size, activation)
+                self.encoder = new_encoder(encoder_hidden_size_list[-1], activation)
         else:
             if encoder:
                 if isinstance(encoder, torch.nn.Module):
@@ -125,25 +123,31 @@ def new_encoder(outsize, activation):
                 else:
                     raise ValueError("illegal encoder instance.")
             else:
-                self.actor_encoder = new_encoder(actor_head_hidden_size, activation)
-                self.critic_encoder = new_encoder(critic_head_hidden_size, activation)
+                self.actor_encoder = new_encoder(encoder_hidden_size_list[-1], activation)
+                self.critic_encoder = new_encoder(encoder_hidden_size_list[-1], activation)
 
         # Head Type
         self.critic_head = RegressionHead(
-            critic_head_hidden_size, 1, critic_head_layer_num, activation=activation, norm_type=norm_type
+            encoder_hidden_size_list[-1],
+            1,
+            critic_head_layer_num,
+            activation=activation,
+            norm_type=norm_type,
+            hidden_size=critic_head_hidden_size
         )
         self.action_space = action_space
         assert self.action_space in ['discrete', 'continuous', 'hybrid'], self.action_space
         if self.action_space == 'continuous':
             self.multi_head = False
             self.actor_head = ReparameterizationHead(
-                actor_head_hidden_size,
+                encoder_hidden_size_list[-1],
                 action_shape,
                 actor_head_layer_num,
                 sigma_type=sigma_type,
                 activation=activation,
                 norm_type=norm_type,
-                bound_type=bound_type
+                bound_type=bound_type,
+                hidden_size=actor_head_hidden_size,
             )
         elif self.action_space == 'discrete':
             actor_head_cls = DiscreteHead
@@ -172,14 +176,15 @@ def new_encoder(outsize, activation):
             action_shape.action_args_shape = squeeze(action_shape.action_args_shape)
             action_shape.action_type_shape = squeeze(action_shape.action_type_shape)
             actor_action_args = ReparameterizationHead(
-                actor_head_hidden_size,
+                encoder_hidden_size_list[-1],
                 action_shape.action_args_shape,
                 actor_head_layer_num,
                 sigma_type=sigma_type,
                 fixed_sigma_value=fixed_sigma_value,
                 activation=activation,
                 norm_type=norm_type,
                 bound_type=bound_type,
+                hidden_size=actor_head_hidden_size,
             )
             actor_action_type = DiscreteHead(
                 actor_head_hidden_size,