@@ -54,7 +54,7 @@ class PromptAWRPolicy(Policy):
             # (float) Coefficient that controls the exp scale in awr algorithm.
             beta=1.0,
             # (float) Weight of entropy regularization in the loss function.
-            entropy_weight=0.01,
+            entropy_weight=0.001,
             # (Tuple[float, float]) The range of adv. Value that exceeds this range will be clipped.
             adv_range=(-0.5, 0.5),
             # (bool) If set to True, the 'done' signals that indicate the end of an episode due to environment time
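For intuition about `beta` and `adv_range` above: in the AWR update further down, the advantage is first clipped to `adv_range` and then turned into a per-sample weight `exp(adv / beta)` on the log-probability. A minimal sketch with made-up advantage values (not part of the diff):

```python
import torch

beta = 1.0               # default from the hunk above
adv_range = (-0.5, 0.5)  # default from the hunk above

# Made-up advantages; anything outside adv_range is clipped before weighting.
adv = torch.tensor([-2.0, -0.3, 0.0, 0.4, 3.0])
weight = torch.exp(torch.clamp(adv, min=adv_range[0], max=adv_range[1]) / beta)
print(weight)  # ~[0.61, 0.74, 1.00, 1.49, 1.65]
```

With the defaults the weight stays within roughly [0.61, 1.65], so a single large return cannot dominate the batch; a smaller `beta` spreads the weights further apart.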
@@ -82,7 +82,7 @@ class PromptAWRPolicy(Policy):
     def default_model(self) -> Tuple[str, List[str]]:
         """
         Overview:
-            Returns the default model configuration used by the A2C algorithm. ``__init__`` method will \
+            Returns the default model configuration used by the AWR algorithm. ``__init__`` method will \
                 automatically call this method to get the default model setting and create model.

         Returns:
@@ -94,7 +94,7 @@ def default_model(self) -> Tuple[str, List[str]]:
     def _init_learn(self) -> None:
         """
         Overview:
-            Initialize the learn mode of policy, including related attributes and modules. For A2C, it mainly \
+            Initialize the learn mode of policy, including related attributes and modules. For AWR, it mainly \
                 contains optimizer, algorithm-specific arguments such as value_weight, entropy_weight, adv_norm
                 and grad_norm, and main model. \
             This method will be called in ``__init__`` method if ``learn`` field is in ``enable_field``.
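The actual `_init_learn` body is not part of this diff. As a rough orientation only, the state such an init prepares is an optimizer over the main model plus the AWR hyper-parameters from `cfg.learn`; the toy model, the `Adam` optimizer, and the learning rate below are assumptions for illustration:

```python
import torch
import torch.nn as nn

model = nn.Linear(16, 4)  # toy stand-in for the prompt-selection model
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)  # assumed optimizer / lr
entropy_weight = 0.001    # new default from this diff
beta = 1.0                # default from this diff
adv_range = (-0.5, 0.5)   # default from this diff
```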
@@ -141,26 +141,33 @@ def _forward_learn(self, data: List[Dict[str, Any]]) -> Dict[str, Any]:
 
             # Prepare train_sample (the question to be answered) and the candidate_samples (the prompts to be selected)
             train_samples, cand_samples = batch["obs"]["train_sample"], batch["obs"]["candidate_samples"]
-            for ii in range(len(cand_samples)):
-                cand_samples[ii] = cand_samples[ii][0]
+            for cand_n in range(len(cand_samples)):
+                cand_samples[cand_n] = cand_samples[cand_n][0]
             output = self._learn_model.forward(train_samples, cand_samples, mode='compute_actor_critic')
             return_ = batch['return']
 
-            # calculate PG loss
-            real_act = batch['action']  # shape: (B, shot_number)
+            # Calculate the AWR loss.
+            real_act = batch['action']
+
+            # Ensure real_act has shape (B, shot_number).
             if len(real_act.shape) == 1:
                 real_act = real_act.unsqueeze(-1)
-            # Calculate loss.
+
+            # Calculate the different parts of the loss.
             total_policy_loss, total_entropy_loss, total_value_loss = 0, 0, 0
-            for ii in range(self._cfg.shot_number):
-                log_prob = output['dist'].log_prob(real_act[:, ii])
+            for shot_n in range(self._cfg.shot_number):
+                log_prob = output['dist'].log_prob(real_act[:, shot_n])
+                # Clamp the adv for better stability.
                 adv = torch.clamp(
                     return_ - batch['value'], min=self._cfg.learn.norm_range[0], max=self._cfg.learn.norm_range[1]
                 )
+                # The policy loss for the AWR algorithm.
                 policy_loss = -(log_prob * torch.exp(adv / self._cfg.learn.beta)).mean()
                 total_policy_loss += policy_loss
+                # The value loss for the AWR algorithm.
                 value_loss = ((return_ - output['value']) ** 2).mean()
                 total_value_loss += value_loss
+                # The entropy loss for the AWR algorithm.
                 total_entropy_loss += -self._cfg.learn.entropy_weight * output['dist'].entropy().mean()
             total_loss = total_entropy_loss + total_policy_loss + total_value_loss
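To make the loss composition in `_forward_learn` concrete, here is a self-contained toy re-run of the same per-shot loop (batch size 4, 8 candidate prompts, `shot_number=2`). The `Categorical` distribution and every tensor below are made-up stand-ins for the policy's real inputs and outputs, not the actual model:

```python
import torch
from torch.distributions import Categorical

B, num_cand, shot_number = 4, 8, 2
beta, entropy_weight, norm_range = 1.0, 0.001, (-0.5, 0.5)

logits = torch.randn(B, num_cand, requires_grad=True)    # stand-in for the actor head
dist = Categorical(logits=logits)                        # stand-in for output['dist']
value_pred = torch.randn(B, requires_grad=True)          # stand-in for output['value']
value_old = torch.randn(B)                               # stand-in for batch['value']
return_ = torch.randn(B)                                 # stand-in for batch['return']
real_act = torch.randint(0, num_cand, (B, shot_number))  # stand-in for batch['action']

total_policy_loss, total_entropy_loss, total_value_loss = 0, 0, 0
for shot_n in range(shot_number):
    log_prob = dist.log_prob(real_act[:, shot_n])
    adv = torch.clamp(return_ - value_old, min=norm_range[0], max=norm_range[1])
    total_policy_loss += -(log_prob * torch.exp(adv / beta)).mean()
    total_value_loss += ((return_ - value_pred) ** 2).mean()
    total_entropy_loss += -entropy_weight * dist.entropy().mean()

total_loss = total_policy_loss + total_value_loss + total_entropy_loss
total_loss.backward()  # gradients reach logits (policy/entropy terms) and value_pred (value term)
```

As in the diff, the advantage is built from `batch['value']` (presumably recorded at collection time), so the policy term only backpropagates through `log_prob`, while the value head is fit to the return through `output['value']`.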