# app.py
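"""Flask demo that tags German legal text with a fine-tuned BERT
token-classification model (NER for courts, laws, judges, and other
legal entities)."""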

import pickle
import warnings

import numpy as np
import torch
from flask import Flask, render_template, request
from markupsafe import Markup  # Flask 2.3+ removed flask.Markup; it lives in markupsafe
from transformers import BertForTokenClassification, logging

warnings.simplefilter(action="ignore", category=Warning)
logging.set_verbosity(logging.ERROR)
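
# Artifacts are loaded once at startup; the first BERT load can take a while.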
print("[+] this might take a few seconds or minutes...")
app = Flask(__name__)
device = torch.device("cpu")
tokenizer = pickle.load(open("tokenizer.pkl", "rb"))
tag_values = pickle.load(open("tag_values.pkl", "rb"))
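
# Rebuild the architecture, then load the fine-tuned weights from disk.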
model = BertForTokenClassification.from_pretrained(
    "bert-base-german-cased",
    num_labels=len(tag_values),
    output_attentions=False,
    output_hidden_states=False,
)
model.load_state_dict(torch.load("model.pt", map_location=device))
model.eval()  # inference only: make sure dropout is disabled
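
# Human-readable names for the entity tags (the B-/I- prefix is stripped on lookup).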
classes = {
    "AN": "Lawyer",
    "EUN": "European legal norm",
    "GRT": "Court",
    "GS": "Law",
    "INN": "Institution",
    "LD": "Country",
    "LDS": "Landscape",
    "LIT": "Legal literature",
    "MRK": "Brand",
    "ORG": "Organization",
    "PER": "Person",
    "RR": "Judge",
    "RS": "Court decision",
    "ST": "City",
    "STR": "Street",
    "UN": "Company",
    "VO": "Ordinance",
    "VS": "Regulation",
    "VT": "Contract",
}
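

# Landing page with the input form.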
@app.route("/", methods=["GET"])
def Home():
return render_template("index.html")
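

# Run NER over the submitted sentence and return it as annotated HTML.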
@app.route("/analyze", methods=["POST"])
def analyze():
if request.method == "POST":
test_sentence = str(request.form["testsentence"])
tokenized_sentence = tokenizer.encode(test_sentence)
input_ids = torch.tensor([tokenized_sentence])
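
    # Predict one label index per token; no gradients are needed at inference time.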
    with torch.no_grad():
        output = model(input_ids)
    label_indices = np.argmax(output[0].numpy(), axis=2)
    tokens = tokenizer.convert_ids_to_tokens(input_ids.numpy()[0])
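
    # Re-assemble WordPiece pieces ("##"-prefixed) into whole words, keeping the
    # label predicted for the first piece of each word.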
    new_tokens, new_labels = [], []
    for token, label_idx in zip(tokens, label_indices[0]):
        if token.startswith("##"):
            new_tokens[-1] = new_tokens[-1] + token[2:]
        else:
            new_labels.append(tag_values[label_idx])
            new_tokens.append(token)
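
    # A sentence-final period sometimes receives an entity label; glue it back
    # onto the preceding token and drop the stray period.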
    to_remove = []
    for idx in range(len(new_tokens)):
        if new_tokens[idx] == "." and new_labels[idx] != "O":
            new_tokens[idx - 1] += "."
            to_remove.append(idx)
    new_tokens = [
        token for idx, token in enumerate(new_tokens) if idx not in to_remove
    ]
    new_labels = [
        label for idx, label in enumerate(new_labels) if idx not in to_remove
    ]
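
    # Wrap each labelled token in an <abbr> tooltip that shows its class name.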
output = ""
for token, label in zip(new_tokens, new_labels):
if label != "O":
cls = classes[label.split("-")[-1]]
output += '<abbr title="{}"><b>{}</b><t style="color:#ff4000">[{}]</t></abbr> '.format(
cls, token, label
)
else:
output += "{}[{}] ".format(token, label)
    output = (
        output.replace("[CLS]", "").replace("[O]", "").replace("[SEP]", "").strip()
    )
    output = output.replace(
        "[UNK]",
        """<abbr title="Unknown token"><b style="color:#545454">[UNK]</b></abbr> """,
    )
    output = (
        "<strong>- Original text -</strong><br><br>{}<br><br>"
        "<strong>- Analyzed text -</strong><br><br>{}<br><br>"
        "<mark><strong>Tip:</strong> Hover over the red-underlined words to see "
        "their classes.</mark>"
    ).format(test_sentence, output)
    # Markup marks the string as safe so Jinja renders the HTML instead of escaping it.
    return render_template("index.html", analyzed_sentence=Markup(output))


if __name__ == "__main__":
    app.run(host="0.0.0.0", port=5050)
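
# Example request once the server is running (illustrative input; the field name
# "testsentence" matches the form field read above):
#   curl -X POST -d "testsentence=Der Bundesgerichtshof hat entschieden." http://localhost:5050/analyze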