-
Notifications
You must be signed in to change notification settings - Fork 0
/
app.py
99 lines (74 loc) · 3.55 KB
/
app.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
import os
from flask import Flask, flash, render_template, redirect, request, url_for
from flask import request
from werkzeug.utils import secure_filename
from model.postit import predict_postit
import torch
from transformers import AutoTokenizer, AutoModelForQuestionAnswering
# Hugging Face model id of the extractive question-answering checkpoint.
name = "mrm8488/bert-small-finetuned-squadv2"
# Tokenizer and model are loaded once, at import time, and are shared by every
# call to answer_question().
tokenizer, model = (
    AutoTokenizer.from_pretrained(name),
    AutoModelForQuestionAnswering.from_pretrained(name),
)
def answer_question(question, answer_text):
    """Extract the span of *answer_text* that best answers *question*.

    Runs the module-level extractive QA ``model`` over the question/context
    pair, picks the highest-scoring start and end token positions, and joins
    the tokens in between back into a string.

    Args:
        question: The question to ask.
        answer_text: The context passage expected to contain the answer.

    Returns:
        The answer span as a string. It is empty when the predicted end
        position precedes the predicted start, or when the span consists only
        of the ``[CLS]``/``[SEP]`` special tokens.
    """
    # Tokenize the question/context pair into vocabulary ids (for a BERT-style
    # tokenizer: [CLS] question [SEP] answer_text [SEP]).
    token_indices = tokenizer.encode(question, answer_text)

    # Segment ids: 0 up to and including the first [SEP] (the question),
    # 1 for the remainder (the context).
    sep_index = token_indices.index(tokenizer.sep_token_id)
    seg_one = sep_index + 1
    seg_two = len(token_indices) - seg_one
    segment_ids = [0] * seg_one + [1] * seg_two

    # Inference only: don't build an autograd graph.
    with torch.no_grad():
        outputs = model(
            torch.tensor([token_indices]),
            token_type_ids=torch.tensor([segment_ids]),
        )
    # Index instead of tuple-unpacking: transformers >= 4 returns a ModelOutput,
    # a dict subclass whose iteration yields key *names* ("start_logits", ...)
    # rather than tensors. Integer indexing works for ModelOutput and tuples.
    start_scores, end_scores = outputs[0], outputs[1]

    # Positions of the highest start and end scores.
    answer_begin = int(torch.argmax(start_scores))
    answer_end = int(torch.argmax(end_scores))

    # Map ids back to token strings and slice out the predicted span.
    indices_tokens = tokenizer.convert_ids_to_tokens(token_indices)
    answer = indices_tokens[answer_begin:answer_end + 1]

    # SentencePiece tokenizers (e.g. "twmkn9/albert-base-v2-squad2") mark word
    # starts with "▁"; strip it so either model family yields plain words.
    answer = [word.replace("▁", "") if word.startswith("▁") else word for word in answer]
    # Drop special tokens and re-attach WordPiece continuations ("##").
    return " ".join(answer).replace("[CLS]", "").replace("[SEP]", "").replace(" ##", "")
app = Flask(__name__)
# Signs the session cookie that flash() messages are stored in. Read it from the
# environment so deployments can supply a real secret; the hard-coded fallback
# is only suitable for local development.
app.secret_key = os.environ.get("SECRET_KEY", "secret key")
# Store uploads inside the app's own static folder. display_image redirects to
# url_for('static', filename='uploads/...'), which is served from
# app.static_folder, so uploads stay reachable whatever the working directory.
app.config['IMAGE_UPLOADS'] = os.path.join(app.static_folder, 'uploads')
# File extensions accepted for upload. Restricted to image formats: uploads are
# handled as images throughout (the "image" form field, IMAGE_UPLOADS config),
# and the rejection message shown to users lists exactly these four types.
ALLOWED_EXTENSIONS = {'png', 'jpg', 'jpeg', 'gif'}


def allowed_file(filename):
    """Return True if *filename* ends in an extension from ALLOWED_EXTENSIONS.

    The comparison is case-insensitive and only considers the text after the
    last dot, so ``"photo.PNG"`` is accepted while ``"archive"`` (no dot) and
    ``"notes.txt"`` are rejected.
    """
    return '.' in filename and filename.rsplit('.', 1)[1].lower() in ALLOWED_EXTENSIONS
@app.route('/', methods=['GET', 'POST'])
def upload_image():
    """Render the upload form (GET) or save an uploaded image (POST).

    On POST, validates the ``image`` form field, stores the file under
    ``app.config['IMAGE_UPLOADS']`` using a sanitised name, and re-renders
    ``index.html`` with ``filename`` set. Any validation failure flashes a
    message and redirects back to the same URL.
    """
    if request.method != 'POST':
        return render_template("index.html")

    # Check that the POST request actually has the file part: indexing
    # request.files directly would abort with a 400 instead. A missing or
    # empty filename means the user submitted the form without a file.
    image = request.files.get("image")
    if image is None or not image.filename:
        flash('No file selected')
        return redirect(request.url)
    if not allowed_file(image.filename):
        flash('Allowed image types are -> png, jpg, jpeg, gif')
        return redirect(request.url)

    # secure_filename strips path separators and unsafe characters; it can
    # return an empty string, which must not be used as a file name.
    filename = secure_filename(image.filename)
    if not filename:
        flash('Invalid file name')
        return redirect(request.url)

    upload_dir = app.config['IMAGE_UPLOADS']
    # The uploads folder may not exist yet (e.g. on a fresh deploy).
    os.makedirs(upload_dir, exist_ok=True)
    image.save(os.path.join(upload_dir, filename))
    flash("image saved")
    return render_template("index.html", filename=filename)
@app.route('/display/<filename>')
def display_image(filename):
    """Send the client to the static URL of an uploaded image (HTTP 301)."""
    print(f'display_image filename: {filename}')
    # Uploaded files live under the "uploads/" subfolder of the static folder.
    static_url = url_for('static', filename=f'uploads/{filename}')
    return redirect(static_url, code=301)
# Served on its own rule: it used to share '/display/<filename>' with
# display_image, so Flask always dispatched to display_image and this view was
# unreachable.
@app.route('/predict/<filename>')
def predict(filename):
    """Run the post-it predictor on *filename* and return its result.

    Args:
        filename: File name taken from the URL path.

    Returns:
        Whatever ``predict_postit`` returns, used directly as the response.
    """
    # NOTE(review): only the bare file name from the URL is passed on; confirm
    # whether predict_postit resolves it against app.config['IMAGE_UPLOADS'] or
    # expects a full path.
    # NOTE(review): the result is returned as the response body, which assumes
    # predict_postit returns a str (or another valid Flask return value).
    text = predict_postit(filename)
    return text
if __name__ == "__main__":
    # Listen on all interfaces. The port comes from the PORT environment
    # variable when it is set, otherwise 5000.
    port = int(os.getenv("PORT", "5000"))
    app.run(port=port, host="0.0.0.0")