prior_model.py
from sentence_transformers import SentenceTransformer
import torch
from typing import List
import os
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers.modeling_outputs import CausalLMOutputWithPast
from utils import gather_log_probabilities


class PriorModel:
    """Scores sentences with a sigmoid head on top of MiniLM sentence embeddings."""

    def __init__(self, device, mlp_path=None, hidden_size=384):
        # all-MiniLM-L6-v2 produces 384-dimensional sentence embeddings.
        self.model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2').to(device)
        # Linear head mapping an embedding to a single prior logit.
        self.mlp = torch.nn.Linear(hidden_size, 1).to(device)
        if mlp_path is not None:
            self.mlp.load_state_dict(torch.load(mlp_path))
        self.device = device

    def forward(self, sentences: List[str]):
        # Encode the sentences, then squash the linear output to a prior in (0, 1).
        vector = self.model.encode(sentences)
        vector = torch.tensor(vector).to(self.device)
        prior = self.mlp(vector)
        prior = torch.sigmoid(prior)
        return prior

    def save_mlp(self, path):
        torch.save(self.mlp.state_dict(), path)


class PriorModelCodegen:
    """Scores strings by the likelihood a CodeGen causal LM assigns to them."""

    def __init__(self, device, normalize_len=True, path="/media/george/Projects/Labs/CogSci_labs/models/codegen-350M-mono"):
        self.tokenizer = AutoTokenizer.from_pretrained(path)
        self.model = AutoModelForCausalLM.from_pretrained(path)
        self.normalize_len = normalize_len
        self.model.to(device)
        self.device = device

    def forward(self, sentence: List[str]):
        input_ids = self.tokenizer(sentence, return_tensors="pt").input_ids.to(self.device)
        with torch.no_grad():
            output: CausalLMOutputWithPast = self.model(input_ids)
        logits = output.logits
        labels = input_ids.clone().detach()
        # Logits at position t predict the token at position t + 1, so shift by one.
        log_prob = gather_log_probabilities(logits[:, :-1], labels[:, 1:])
        seq_len = labels.size(1)
        if self.normalize_len:
            # Average log-likelihood, normalized by the full token count of the input.
            prob = log_prob.sum(dim=-1) / seq_len
        else:
            prob = log_prob.sum(dim=-1)
        prob = torch.exp(prob)
        return prob


def test():
    # PriorModel requires a device; run on CPU for a quick smoke test.
    model = PriorModel(device='cpu')
    print(model.forward(['I am a sentence', 'I am another sentence']))
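

# Hypothetical usage sketch for PriorModelCodegen (test_codegen and the sample
# snippet are illustrative only); it assumes the local CodeGen checkpoint path
# used in __init__ above is available on this machine.
def test_codegen():
    model = PriorModelCodegen(device='cpu')
    # A single code snippet wrapped in a list, matching forward()'s List[str] hint.
    print(model.forward(['def add(a, b):\n    return a + b\n']))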


if __name__ == '__main__':
    test()