-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathapp.py
75 lines (51 loc) · 2.47 KB
/
app.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
import os
from flask import Flask, render_template
from flask import request
import torch
from transformers import AutoTokenizer, AutoModelForQuestionAnswering
# Hugging Face model id (by its name, a small BERT fine-tuned on SQuAD v2).
name = "mrm8488/bert-small-finetuned-squadv2"

# Load tokenizer and QA model once at import time; every request reuses them.
tokenizer, model = (
    AutoTokenizer.from_pretrained(name),
    AutoModelForQuestionAnswering.from_pretrained(name),
)
def answer_question(question, answer_text):
    """Extract the span of ``answer_text`` that best answers ``question``.

    The question and passage are encoded as one sequence pair, the QA model
    scores every token as a candidate answer start and end, and the tokens
    between the highest-scoring start and end are joined back into text.

    Args:
        question: The question to answer.
        answer_text: The passage (context) that should contain the answer.

    Returns:
        The predicted answer string, with ``[CLS]``/``[SEP]`` and WordPiece
        ``##`` continuation markers removed. It is empty when the predicted
        end precedes the predicted start, or when the span is only ``[CLS]``.
    """
    # Token ids for the (question, answer_text) pair, special tokens included.
    token_indices = tokenizer.encode(question, answer_text)
    # NOTE(review): no truncation is applied, so passages longer than the
    # model's maximum sequence length will presumably fail inside the model;
    # consider truncating the passage.

    # Everything up to and including the first [SEP] is segment A (question);
    # the remaining tokens are segment B (passage).
    seg_one = token_indices.index(tokenizer.sep_token_id) + 1
    seg_two = len(token_indices) - seg_one
    segment_ids = [0] * seg_one + [1] * seg_two

    # Inference only: don't build an autograd graph.
    with torch.no_grad():
        outputs = model(
            torch.tensor([token_indices]),  # tokens for question + passage
            token_type_ids=torch.tensor([segment_ids]),  # question vs passage
        )
    # Index positionally: this works for the legacy tuple return value and for
    # the ModelOutput returned by transformers>=4. Tuple-unpacking a
    # ModelOutput would yield its key names, not the score tensors.
    start_scores, end_scores = outputs[0], outputs[1]

    # Positions with the highest start / end scores (batch size is 1).
    answer_begin = int(torch.argmax(start_scores))
    answer_end = int(torch.argmax(end_scores))

    indices_tokens = tokenizer.convert_ids_to_tokens(token_indices)
    answer = indices_tokens[answer_begin:answer_end + 1]
    # Strip the SentencePiece word-start marker (needed for models such as
    # "twmkn9/albert-base-v2-squad2"; a no-op for WordPiece tokens).
    answer = [word.replace("▁", "") if word.startswith("▁") else word for word in answer]
    answer = " ".join(answer).replace("[CLS]", "").replace("[SEP]", "").replace(" ##", "")
    return answer
app = Flask(__name__)


@app.route('/', methods=['GET', 'POST'])
def index():
    """Serve the QA page; on POST, answer the submitted question.

    On POST, ``index.html`` is rendered with ``result`` set to the list
    ``[question, predicted answer, paragraph]``. On GET it is rendered
    without ``result``.
    """
    if request.method != 'POST':
        return render_template("index.html")

    submitted = request.form
    # Read 'paragraph' before 'question' (as before), so a missing field is
    # reported in the same order.
    passage = submitted['paragraph']
    question = submitted['question']
    # The template presumably reads these by position; keep the order.
    result = [question, answer_question(question, passage), passage]
    return render_template("index.html", result=result)
if __name__ == '__main__':
    # Honour $PORT when it is set (e.g. by a hosting platform); default to 5000.
    port = int(os.getenv("PORT", 5000))
    # Listen on all interfaces, not just localhost.
    app.run(port=port, host='0.0.0.0')