
Commit b9c7df7

Author: whl
Commit message: polish
Parent: 7c3913d

File tree: 3 files changed (+27, -18 lines)


ding/model/template/language_transformer.py (+5, -3)

@@ -15,6 +15,8 @@ class LanguageTransformer(nn.Module):
     """
     Overview:
         The LanguageTransformer network. Download a pre-trained language model and add head on it.
+        In the default case, we use BERT model as the text encoder, whose bi-directional character is good
+        for obtaining the embedding of the whole sentence.
     Interfaces:
         ``__init__``, ``forward``
     """
@@ -35,12 +37,12 @@ def __init__(
         Arguments:
            - model_name (:obj:`str`): The base language model name in huggingface, such as "bert-base-uncased".
            - add_linear (:obj:`bool`): Whether to add a linear layer on the top of language model, defaults to be \
-              ``False``.
+                ``False``.
            - embedding_size (:obj:`int`): The embedding size of the added linear layer, such as 128.
            - freeze_encoder (:obj:`bool`): Whether to freeze the encoder language model while training, \
-              defaults to be ``True``.
+                defaults to be ``True``.
            - hidden_dim (:obj:`int`): The embedding dimension of the encoding model (e.g. BERT). This value should \
-              correspond to the model you use. For bert-base-uncased, this value is 768.
+                correspond to the model you use. For bert-base-uncased, this value is 768.
            - norm_embedding (:obj:`bool`): Whether to normalize the embedding vectors. Default to be ``False``.
         """
         super().__init__()
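
For reference, a minimal construction sketch built from the documented arguments above. The argument values here are illustrative choices, not part of the commit; only the parameter names and their meanings come from the docstring.

    from ding.model.template.language_transformer import LanguageTransformer

    # Illustrative values; the parameter names are taken from the docstring above.
    model = LanguageTransformer(
        model_name='bert-base-uncased',  # base huggingface language model
        add_linear=True,                 # attach a linear head on top of the encoder
        embedding_size=128,              # output size of the added linear layer
        freeze_encoder=True,             # keep the encoder weights fixed during training
        hidden_dim=768,                  # hidden size of bert-base-uncased
        norm_embedding=False,            # leave embedding vectors unnormalized
    )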

ding/policy/prompt_awr.py (+17, -10)

@@ -54,7 +54,7 @@ class PromptAWRPolicy(Policy):
         # (float) Coefficient that controls the exp scale in awr algorithm.
         beta=1.0,
         # (float) Weight of entropy regularization in the loss function.
-        entropy_weight=0.01,
+        entropy_weight=0.001,
         # (Tuple[float, float]) The range of adv. Value that exceeds this range will be clipped.
         adv_range=(-0.5, 0.5),
         # (bool) If set to True, the 'done' signals that indicate the end of an episode due to environment time
@@ -82,7 +82,7 @@ class PromptAWRPolicy(Policy):
     def default_model(self) -> Tuple[str, List[str]]:
         """
         Overview:
-            Returns the default model configuration used by the A2C algorithm. ``__init__`` method will \
+            Returns the default model configuration used by the AWR algorithm. ``__init__`` method will \
             automatically call this method to get the default model setting and create model.

         Returns:
@@ -94,7 +94,7 @@ def default_model(self) -> Tuple[str, List[str]]:
     def _init_learn(self) -> None:
         """
         Overview:
-            Initialize the learn mode of policy, including related attributes and modules. For A2C, it mainly \
+            Initialize the learn mode of policy, including related attributes and modules. For AWR, it mainly \
             contains optimizer, algorithm-specific arguments such as value_weight, entropy_weight, adv_norm
             and grad_norm, and main model. \
             This method will be called in ``__init__`` method if ``learn`` field is in ``enable_field``.
@@ -141,26 +141,33 @@ def _forward_learn(self, data: List[Dict[str, Any]]) -> Dict[str, Any]:

         # Prepare train_sample (the question to be answered) and the candidate_samples (the prompts to be selected)
         train_samples, cand_samples = batch["obs"]["train_sample"], batch["obs"]["candidate_samples"]
-        for ii in range(len(cand_samples)):
-            cand_samples[ii] = cand_samples[ii][0]
+        for cand_n in range(len(cand_samples)):
+            cand_samples[cand_n] = cand_samples[cand_n][0]
         output = self._learn_model.forward(train_samples, cand_samples, mode='compute_actor_critic')
         return_ = batch['return']

-        # calculate PG loss
-        real_act = batch['action']  # shape: (B, shot_number)
+        # Calculate AWR loss
+        real_act = batch['action']
+
+        # Ensure the shape of real_act is: (B, shot_number)
         if len(real_act.shape) == 1:
             real_act = real_act.unsqueeze(-1)
-        # Calculate loss.
+
+        # Calculate different parts of loss.
         total_policy_loss, total_entropy_loss, total_value_loss = 0, 0, 0
-        for ii in range(self._cfg.shot_number):
-            log_prob = output['dist'].log_prob(real_act[:, ii])
+        for shot_n in range(self._cfg.shot_number):
+            log_prob = output['dist'].log_prob(real_act[:, shot_n])
+            # Clamp the adv for better stability.
             adv = torch.clamp(
                 return_ - batch['value'], min=self._cfg.learn.norm_range[0], max=self._cfg.learn.norm_range[1]
             )
+            # The policy loss for AWR algorithm.
             policy_loss = -(log_prob * torch.exp(adv / self._cfg.learn.beta)).mean()
             total_policy_loss += policy_loss
+            # The value loss for AWR algorithm.
             value_loss = ((return_ - output['value']) ** 2).mean()
             total_value_loss += value_loss
+            # The entropy loss for AWR algorithm.
             total_entropy_loss += -self._cfg.learn.entropy_weight * output['dist'].entropy().mean()
         total_loss = total_entropy_loss + total_policy_loss + total_value_loss
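
The rewritten loop computes three loss terms per shot: an AWR policy term -E[log pi(a|s) * exp(adv / beta)] with the advantage clamped for stability, a squared-error value term, and an entropy bonus scaled by the new entropy_weight default. Below is a self-contained sketch of the same computation on dummy tensors; the batch size, candidate count, and the Categorical action distribution are assumptions for illustration only.

    import torch

    beta = 1.0               # exp temperature, as in the config above
    adv_range = (-0.5, 0.5)  # advantage clipping range, as in the config above
    entropy_weight = 0.001   # new default set by this commit

    # Dummy policy outputs for a batch of 8 states and 4 candidate prompts.
    batch_size, num_candidates = 8, 4
    dist = torch.distributions.Categorical(logits=torch.randn(batch_size, num_candidates))
    action = torch.randint(num_candidates, (batch_size, ))
    value = torch.randn(batch_size)    # critic estimate V(s)
    return_ = torch.randn(batch_size)  # empirical return

    # The same three terms as in _forward_learn, for a single shot.
    log_prob = dist.log_prob(action)
    adv = torch.clamp(return_ - value, min=adv_range[0], max=adv_range[1])
    policy_loss = -(log_prob * torch.exp(adv / beta)).mean()
    value_loss = ((return_ - value) ** 2).mean()
    entropy_loss = -entropy_weight * dist.entropy().mean()
    total_loss = policy_loss + value_loss + entropy_loss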

dizoo/tabmwp/config/tabmwp_awr_config.py (+5, -5)

@@ -1,7 +1,7 @@
 from easydict import EasyDict

-tabmwp_prompt_pg_config = dict(
-    exp_name='tabmwp_prompt_pg_seed0',
+tabmwp_prompt_awr_config = dict(
+    exp_name='tabmwp_prompt_awr_seed0',
     env=dict(
         collector_env_num=1,
         evaluator_env_num=1,
@@ -48,9 +48,9 @@
         eval=dict(evaluator=dict(eval_freq=500, )),
     ),
 )
-main_config = EasyDict(tabmwp_prompt_pg_config)
+main_config = EasyDict(tabmwp_prompt_awr_config)

-tabmwp_prompt_pg_config = dict(
+tabmwp_prompt_awr_config = dict(
     env=dict(
         type='tabmwp',
         import_names=['dizoo.tabmwp.envs.tabmwp_env'],
@@ -59,7 +59,7 @@
     policy=dict(type='prompt_awr'),
     replay_buffer=dict(type='naive'),
 )
-create_config = EasyDict(tabmwp_prompt_pg_config)
+create_config = EasyDict(tabmwp_prompt_awr_config)

 if __name__ == '__main__':
     from ding.entry import serial_pipeline_onpolicy
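
With the rename in place, the two EasyDicts are consumed as usual by the on-policy entry point imported in the __main__ block. A sketch of the call, assuming DI-engine's standard pattern of passing [main_config, create_config] plus a seed:

    from ding.entry import serial_pipeline_onpolicy

    # main_config and create_config are the EasyDicts defined above.
    serial_pipeline_onpolicy([main_config, create_config], seed=0)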
