
Commit 3898386

feature(whl): add AWR algorithm (#828)
* init commit
* reformat
* polish
* polish readme
* reformat
* polish
1 parent 6ae1396 commit 3898386

File tree

8 files changed (+421, -18 lines)


README.md

+2-1
@@ -58,7 +58,7 @@ It provides **python-first** and **asynchronous-native** task and middleware abs
 - Offline RL algorithms: BCQ, CQL, TD3BC, Decision Transformer, EDAC, Diffuser, Decision Diffuser, SO2
 - Model-based RL algorithms: SVG, STEVE, MBPO, DDPPO, DreamerV3
 - Exploration algorithms: HER, RND, ICM, NGU
-- LLM + RL Algorithms: PPO-max, DPO, PromptPG
+- LLM + RL Algorithms: PPO-max, DPO, PromptPG, PromptAWR
 - Other algorithms: such as PER, PLR, PCGrad
 - MCTS + RL algorithms: AlphaZero, MuZero, please refer to [LightZero](https://github.com/opendilab/LightZero)
 - Generative Model + RL algorithms: Diffusion-QL, QGPO, SRPO, please refer to [GenerativeRL](https://github.com/opendilab/GenerativeRL)

@@ -283,6 +283,7 @@ P.S: The `.py` file in `Runnable Demo` can be found in `dizoo`
 | 54 | [ST-DIM](https://arxiv.org/pdf/1906.08226.pdf) | ![other](https://img.shields.io/badge/-other-lightgrey) | [torch_utils/loss/contrastive_loss](https://github.com/opendilab/DI-engine/blob/main/ding/torch_utils/loss/contrastive_loss.py) | ding -m serial -c cartpole_dqn_stdim_config.py -s 0 |
 | 55 | [PLR](https://arxiv.org/pdf/2010.03934.pdf) | ![other](https://img.shields.io/badge/-other-lightgrey) | [PLR doc](https://di-engine-docs.readthedocs.io/en/latest/12_policies/plr.html)<br>[data/level_replay/level_sampler](https://github.com/opendilab/DI-engine/blob/main/ding/data/level_replay/level_sampler.py) | python3 -u bigfish_plr_config.py -s 0 |
 | 56 | [PCGrad](https://arxiv.org/pdf/2001.06782.pdf) | ![other](https://img.shields.io/badge/-other-lightgrey) | [torch_utils/optimizer_helper/PCGrad](https://github.com/opendilab/DI-engine/blob/main/ding/data/torch_utils/optimizer_helper.py) | python3 -u multi_mnist_pcgrad_main.py -s 0 |
+| 57 | [AWR](https://arxiv.org/pdf/1910.00177) | ![discrete](https://img.shields.io/badge/-discrete-brightgreen) | [policy/prompt_awr](https://github.com/opendilab/DI-engine/blob/main/ding/policy/prompt_awr.py) | python3 -u tabmwp_awr_config.py |

 </details>
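For context, AWR (Advantage-Weighted Regression, the paper linked in the new table row) trains the actor with a regression loss in which each sample is weighted by its exponentiated advantage, while the critic is regressed onto observed returns. The sketch below illustrates that objective in generic PyTorch; it is not the contents of ding/policy/prompt_awr.py (which this diff view does not display), and beta / weight_clip are placeholder hyperparameter names.

import torch
import torch.nn.functional as F

def awr_losses(log_prob: torch.Tensor, value: torch.Tensor, returns: torch.Tensor,
               beta: float = 1.0, weight_clip: float = 20.0):
    # Advantage of each sampled action, using the critic as a baseline.
    adv = returns - value.detach()
    # Exponentiated, clipped advantage weights: the defining ingredient of AWR.
    weight = torch.clamp(torch.exp(adv / beta), max=weight_clip)
    # Actor: advantage-weighted log-likelihood regression onto the taken actions.
    actor_loss = -(weight * log_prob).mean()
    # Critic: plain regression of the value prediction onto the returns.
    critic_loss = F.mse_loss(value, returns)
    return actor_loss, critic_loss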

ding/model/template/language_transformer.py

+39-12
@@ -1,4 +1,4 @@
-from typing import List, Dict
+from typing import List, Dict, Optional
 import torch
 from torch import nn

@@ -15,31 +15,44 @@ class LanguageTransformer(nn.Module):
     """
     Overview:
         The LanguageTransformer network. Download a pre-trained language model and add head on it.
+        In the default case, we use BERT model as the text encoder, whose bi-directional character is good
+        for obtaining the embedding of the whole sentence.
     Interfaces:
         ``__init__``, ``forward``
     """
+    mode = ['compute_actor', 'compute_critic', 'compute_actor_critic']

     def __init__(
             self,
             model_name: str = "bert-base-uncased",
             add_linear: bool = False,
             embedding_size: int = 128,
-            freeze_encoder: bool = True
+            freeze_encoder: bool = True,
+            hidden_dim: int = 768,
+            norm_embedding: bool = False
     ) -> None:
         """
         Overview:
             Init the LanguageTransformer Model according to input arguments.
         Arguments:
             - model_name (:obj:`str`): The base language model name in huggingface, such as "bert-base-uncased".
             - add_linear (:obj:`bool`): Whether to add a linear layer on the top of language model, defaults to be \
-                ``False``.
+                ``False``.
             - embedding_size (:obj:`int`): The embedding size of the added linear layer, such as 128.
             - freeze_encoder (:obj:`bool`): Whether to freeze the encoder language model while training, \
-                defaults to be ``True``.
+                defaults to be ``True``.
+            - hidden_dim (:obj:`int`): The embedding dimension of the encoding model (e.g. BERT). This value should \
+                correspond to the model you use. For bert-base-uncased, this value is 768.
+            - norm_embedding (:obj:`bool`): Whether to normalize the embedding vectors. Default to be ``False``.
         """
         super().__init__()
         self.tokenizer = AutoTokenizer.from_pretrained(model_name)
         self.model = AutoModelForTokenClassification.from_pretrained(model_name)
+        in_channel = hidden_dim if not add_linear else embedding_size
+        self.value_head = nn.Linear(in_channel, 1)
+        self.norm = nn.Identity() if not norm_embedding else nn.LayerNorm(
+            normalized_shape=in_channel, elementwise_affine=False
+        )

         # Freeze transformer encoder and only train the linear layer
         if freeze_encoder:

@@ -49,9 +62,7 @@ def __init__(
         if add_linear:
             # Add a small, adjustable linear layer on top of language model tuned through RL
             self.embedding_size = embedding_size
-            self.linear = nn.Linear(
-                self.model.config.hidden_size, embedding_size
-            )  # 768 for bert-base-uncased, distilbert-base-uncased
+            self.linear = nn.Linear(self.model.config.hidden_size, embedding_size)
         else:
             self.linear = None

@@ -66,19 +77,27 @@ def _calc_embedding(self, x: list) -> torch.Tensor:
         last_hidden_states = output.hidden_states[-1]
         # Get [CLS] hidden states
         sentence_embedding = last_hidden_states[:, 0, :]  # len(input_list) x hidden_size
+        sentence_embedding = self.norm(sentence_embedding)

         if self.linear:
             sentence_embedding = self.linear(sentence_embedding)  # len(input_list) x embedding_size

         return sentence_embedding

-    def forward(self, train_samples: List[str], candidate_samples: List[str]) -> Dict:
+    def forward(
+            self,
+            train_samples: List[str],
+            candidate_samples: Optional[List[str]] = None,
+            mode: str = 'compute_actor'
+    ) -> Dict:
         """
         Overview:
             LanguageTransformer forward computation graph, input two lists of strings and predict their matching scores.
+            Different ``mode`` will forward with different network modules to get different outputs.
         Arguments:
             - train_samples (:obj:`List[str]`): One list of strings.
-            - candidate_samples (:obj:`List[str]`): The other list of strings to calculate the matching scores.
+            - candidate_samples (:obj:`Optional[List[str]]`): The other list of strings to calculate matching scores.
+            - mode (:obj:`str`): The forward mode, all the modes are defined in the beginning of this class.
         Returns:
             - output (:obj:`Dict`): Output dict data, including the logit of matching scores and the \
                 corresponding ``torch.distributions.Categorical`` object.

@@ -96,7 +115,15 @@ def forward(self, train_samples: List[str], candidate_samples: List[str]) -> Dict:
         >>> scores = model(ctxt_list, cands_list)
         >>> assert scores.shape == (1, 3)
         """
+        assert mode in self.mode
         prompt_embedding = self._calc_embedding(train_samples)
-        cands_embedding = self._calc_embedding(candidate_samples)
-        scores = torch.mm(prompt_embedding, cands_embedding.t())
-        return {'dist': torch.distributions.Categorical(logits=scores), 'logit': scores}
+
+        res_dict = {}
+        if mode in ['compute_actor', 'compute_actor_critic']:
+            cands_embedding = self._calc_embedding(candidate_samples)
+            scores = torch.mm(prompt_embedding, cands_embedding.t())
+            res_dict.update({'dist': torch.distributions.Categorical(logits=scores), 'logit': scores})
+        if mode in ['compute_critic', 'compute_actor_critic']:
+            value = self.value_head(prompt_embedding)
+            res_dict.update({'value': value})
+        return res_dict
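
With these changes, the same encoder can serve as an actor head, a critic head, or both in one call. A minimal usage sketch, assuming the model defined above and placeholder prompt strings (not taken from the repository's tests):

from ding.model.template.language_transformer import LanguageTransformer

# Hypothetical prompts: one context and three candidate continuations.
ctxt_list = ["Question: 1 + 1 = ? Answer:"]
cands_list = ["2", "3", "4"]

model = LanguageTransformer(model_name="bert-base-uncased", add_linear=True, embedding_size=256)

# Actor mode: categorical distribution and logits over the candidates.
actor_out = model(ctxt_list, cands_list, mode='compute_actor')
action = actor_out['dist'].sample()      # index of the sampled candidate

# Critic mode: value estimate of the context alone; candidates are optional here.
critic_out = model(ctxt_list, mode='compute_critic')
value = critic_out['value']              # output of the new value head

# Joint mode returns 'dist', 'logit' and 'value' together.
both = model(ctxt_list, cands_list, mode='compute_actor_critic')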

ding/model/template/tests/test_language_transformer.py

+29-5
@@ -17,9 +17,33 @@ def check_model(self):
         cands_list = [problems[pid] for pid in cand_pids]

         model = LanguageTransformer(model_name="bert-base-uncased", add_linear=True, embedding_size=256)
-        scores = model(ctxt_list, cands_list)
-        assert scores.shape == (1, 3)
+        output = model(ctxt_list, cands_list, mode='compute_actor')
+        assert 'dist' in output.keys() and 'logit' in output.keys() and len(output.keys()) == 2
+        assert output['logit'].shape == (1, 3)

-        model = LanguageTransformer(model_name="bert-base-uncased", add_linear=False, embedding_size=256)
-        scores = model(ctxt_list, cands_list)
-        assert scores.shape == (1, 3)
+        output = model(ctxt_list, cands_list, mode='compute_critic')
+        assert 'value' in output.keys() and len(output.keys()) == 1
+        assert output['value'].shape == (1, )
+
+        output = model(ctxt_list, cands_list, mode='compute_actor_critic')
+        assert 'value' in output.keys() and 'dist' in output.keys() and 'logit' in output.keys() and len(
+            output.keys()
+        ) == 3
+        assert output['value'].shape == (1, )
+        assert output['logit'].shape == (1, 3)
+
+        model = LanguageTransformer(model_name="bert-base-uncased", add_linear=False, norm_embedding=True)
+        output = model(ctxt_list, cands_list, mode='compute_actor')
+        assert 'dist' in output.keys() and 'logit' in output.keys() and len(output.keys()) == 2
+        assert output['logit'].shape == (1, 3)
+
+        output = model(ctxt_list, cands_list, mode='compute_critic')
+        assert 'value' in output.keys() and len(output.keys()) == 1
+        assert output['value'].shape == (1, )
+
+        output = model(ctxt_list, cands_list, mode='compute_actor_critic')
+        assert 'value' in output.keys() and 'dist' in output.keys() and 'logit' in output.keys() and len(
+            output.keys()
+        ) == 3
+        assert output['value'].shape == (1, )
+        assert output['logit'].shape == (1, 3)

ding/policy/__init__.py

+1
@@ -56,4 +56,5 @@
 # new-type policy
 from .ppof import PPOFPolicy
 from .prompt_pg import PromptPGPolicy
+from .prompt_awr import PromptAWRPolicy
 from .happo import HAPPOPolicy

ding/policy/command_mode_policy_instance.py

+6
@@ -52,6 +52,7 @@
 from .prompt_pg import PromptPGPolicy
 from .plan_diffuser import PDPolicy
 from .happo import HAPPOPolicy
+from .prompt_awr import PromptAWRPolicy


 class EpsCommandModePolicy(CommandModePolicy):

@@ -455,3 +456,8 @@ def _get_setting_eval(self, command_info: dict) -> dict:
 @POLICY_REGISTRY.register('prompt_pg_command')
 class PromptPGCommandModePolicy(PromptPGPolicy, DummyCommandModePolicy):
     pass
+
+
+@POLICY_REGISTRY.register('prompt_awr_command')
+class PromptAWRCommandModePolicy(PromptAWRPolicy, DummyCommandModePolicy):
+    pass
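
Registering the command-mode wrapper makes the new policy constructible by name, mirroring the existing 'prompt_pg_command' entry. A minimal sketch of what this enables, assuming the standard POLICY_REGISTRY lookup interface used elsewhere in ding (the registry entry name is the only thing this diff itself guarantees):

from ding.utils import POLICY_REGISTRY
from ding.policy.command_mode_policy_instance import PromptAWRCommandModePolicy

# The registry maps the string used in configs to the policy class added above.
policy_cls = POLICY_REGISTRY.get('prompt_awr_command')
assert policy_cls is PromptAWRCommandModePolicy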
