-from typing import List, Dict
+from typing import List, Dict, Optional
 import torch
 from torch import nn
 
@@ -15,31 +15,44 @@ class LanguageTransformer(nn.Module):
     """
     Overview:
         The LanguageTransformer network. Download a pre-trained language model and add head on it.
+        In the default case, we use the BERT model as the text encoder, whose bi-directional nature is good
+        for obtaining the embedding of the whole sentence.
     Interfaces:
         ``__init__``, ``forward``
     """
+    mode = ['compute_actor', 'compute_critic', 'compute_actor_critic']
 
     def __init__(
             self,
             model_name: str = "bert-base-uncased",
             add_linear: bool = False,
             embedding_size: int = 128,
-            freeze_encoder: bool = True
+            freeze_encoder: bool = True,
+            hidden_dim: int = 768,
+            norm_embedding: bool = False
     ) -> None:
         """
         Overview:
             Init the LanguageTransformer Model according to input arguments.
         Arguments:
             - model_name (:obj:`str`): The base language model name in huggingface, such as "bert-base-uncased".
             - add_linear (:obj:`bool`): Whether to add a linear layer on the top of language model, defaults to be \
-            ``False``.
+                ``False``.
             - embedding_size (:obj:`int`): The embedding size of the added linear layer, such as 128.
             - freeze_encoder (:obj:`bool`): Whether to freeze the encoder language model while training, \
-            defaults to be ``True``.
+                defaults to be ``True``.
+            - hidden_dim (:obj:`int`): The embedding dimension of the encoding model (e.g. BERT). This value should \
+                correspond to the model you use. For bert-base-uncased, this value is 768.
+            - norm_embedding (:obj:`bool`): Whether to normalize the embedding vectors. Defaults to be ``False``.
         """
         super().__init__()
         self.tokenizer = AutoTokenizer.from_pretrained(model_name)
         self.model = AutoModelForTokenClassification.from_pretrained(model_name)
+        in_channel = hidden_dim if not add_linear else embedding_size
+        self.value_head = nn.Linear(in_channel, 1)
+        self.norm = nn.Identity() if not norm_embedding else nn.LayerNorm(
+            normalized_shape=in_channel, elementwise_affine=False
+        )
 
         # Freeze transformer encoder and only train the linear layer
         if freeze_encoder:
@@ -49,9 +62,7 @@ def __init__(
         if add_linear:
             # Add a small, adjustable linear layer on top of language model tuned through RL
             self.embedding_size = embedding_size
-            self.linear = nn.Linear(
-                self.model.config.hidden_size, embedding_size
-            )  # 768 for bert-base-uncased, distilbert-base-uncased
+            self.linear = nn.Linear(self.model.config.hidden_size, embedding_size)
         else:
             self.linear = None
 
@@ -66,19 +77,27 @@ def _calc_embedding(self, x: list) -> torch.Tensor:
         last_hidden_states = output.hidden_states[-1]
         # Get [CLS] hidden states
         sentence_embedding = last_hidden_states[:, 0, :]  # len(input_list) x hidden_size
+        sentence_embedding = self.norm(sentence_embedding)
 
         if self.linear:
             sentence_embedding = self.linear(sentence_embedding)  # len(input_list) x embedding_size
 
         return sentence_embedding
 
-    def forward(self, train_samples: List[str], candidate_samples: List[str]) -> Dict:
+    def forward(
+            self,
+            train_samples: List[str],
+            candidate_samples: Optional[List[str]] = None,
+            mode: str = 'compute_actor'
+    ) -> Dict:
         """
         Overview:
             LanguageTransformer forward computation graph, input two lists of strings and predict their matching scores.
+            Different ``mode`` values forward through different network modules to get different outputs.
         Arguments:
             - train_samples (:obj:`List[str]`): One list of strings.
-            - candidate_samples (:obj:`List[str]`): The other list of strings to calculate the matching scores.
+            - candidate_samples (:obj:`Optional[List[str]]`): The other list of strings to calculate matching scores.
+            - mode (:obj:`str`): The forward mode; all supported modes are defined at the beginning of this class.
         Returns:
             - output (:obj:`Dict`): Output dict data, including the logit of matching scores and the \
                 corresponding ``torch.distributions.Categorical`` object.
@@ -96,7 +115,15 @@ def forward(self, train_samples: List[str], candidate_samples: List[str]) -> Dict:
             >>> scores = model(ctxt_list, cands_list)
             >>> assert scores.shape == (1, 3)
         """
+        assert mode in self.mode
         prompt_embedding = self._calc_embedding(train_samples)
-        cands_embedding = self._calc_embedding(candidate_samples)
-        scores = torch.mm(prompt_embedding, cands_embedding.t())
-        return {'dist': torch.distributions.Categorical(logits=scores), 'logit': scores}
+
+        res_dict = {}
+        if mode in ['compute_actor', 'compute_actor_critic']:
+            cands_embedding = self._calc_embedding(candidate_samples)
+            scores = torch.mm(prompt_embedding, cands_embedding.t())
+            res_dict.update({'dist': torch.distributions.Categorical(logits=scores), 'logit': scores})
+        if mode in ['compute_critic', 'compute_actor_critic']:
+            value = self.value_head(prompt_embedding)
+            res_dict.update({'value': value})
+        return res_dict
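
After this change, ``forward`` returns a dict keyed by the requested mode instead of always computing matching scores. Below is a minimal usage sketch, not part of the patch, showing the three modes; it assumes the patched class is importable as ``ding.model.LanguageTransformer`` (import path assumed) and that ``bert-base-uncased`` can be fetched from the huggingface hub on first use.

```python
# Usage sketch for the new ``mode`` argument (assumptions: DI-engine with this patch
# installed, class exported as ding.model.LanguageTransformer, and network access to
# download bert-base-uncased on the first call).
from ding.model import LanguageTransformer

# Defaults: frozen bert-base-uncased encoder, no extra linear head; norm_embedding
# adds a parameter-free LayerNorm over the 768-dim [CLS] embedding.
model = LanguageTransformer(norm_embedding=True)

ctxt_list = ["Pick the word that names an animal."]
cands_list = ["cat", "tree", "river"]

# compute_actor (default): matching logits over the candidates, as before the patch.
out = model(ctxt_list, cands_list, mode='compute_actor')
assert out['logit'].shape == (1, 3)

# compute_critic: only the scalar value from the new value head; candidates not needed.
out = model(ctxt_list, mode='compute_critic')
assert out['value'].shape == (1, 1)

# compute_actor_critic: both heads in one forward pass over the same prompt embedding.
out = model(ctxt_list, cands_list, mode='compute_actor_critic')
assert {'dist', 'logit', 'value'} <= out.keys()
```

The sketch deliberately leaves ``add_linear`` off when ``norm_embedding`` is on: since ``in_channel`` equals ``embedding_size`` when ``add_linear=True`` but the norm is applied before the linear layer (on the ``hidden_dim``-sized [CLS] vector), combining both flags appears to require ``embedding_size == hidden_dim``.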