# ocr.py (forked from courao/ocr.pytorch)

import pickle as pkl
from math import atan2, cos, degrees, fabs, radians, sin

import cv2
import numpy as np
from pathed import cwd

from train_code.train_crnn.crnn_recognizer import PytorchOcr
from train_code.train_ctpn.ctpn_model_PL import CTPN_Model

# Character set used by the CRNN recognizer, stored as a pickled list of characters
alphabet_list = pkl.load(
    open(cwd / "train_code" / "train_crnn" / "alphabet.pkl", "rb")
)
alphabet = [ord(ch) for ch in alphabet_list]

# Load both models once at import time so repeated ocr() calls reuse them
ctpn_model = CTPN_Model()
ctpn_model.load_checkpoint(eval=True)
recognizer = PytorchOcr(alphabet_unicode=alphabet)


def dis(image):
    """Display an image in a window and block until a key is pressed (debug helper)."""
    cv2.imshow("image", image)
    cv2.waitKey(0)


def sort_box(box):
    """
    Sort detection boxes top-to-bottom by the sum of their four y coordinates.
    """
    box = sorted(box, key=lambda x: sum([x[1], x[3], x[5], x[7]]))
    return box


def dumpRotateImage(img, degree, pt1, pt2, pt3, pt4) -> np.ndarray:
    """
    Rotate the image by `degree` degrees around its center, then crop the
    region spanned by the rotated pt1 (top-left) and pt3 (bottom-right).
    pt2 and pt4 are accepted for interface symmetry but are not used.
    Returns the cropped region as a numpy array.
    """
    height, width = img.shape[:2]
    heightNew = int(
        width * fabs(sin(radians(degree))) + height * fabs(cos(radians(degree)))
    )
    widthNew = int(
        height * fabs(sin(radians(degree))) + width * fabs(cos(radians(degree)))
    )
    matRotation = cv2.getRotationMatrix2D((width // 2, height // 2), degree, 1)
    # Shift the transform so the content stays centred on the enlarged canvas
    matRotation[0, 2] += (widthNew - width) // 2
    matRotation[1, 2] += (heightNew - height) // 2
    imgRotation = cv2.warpAffine(
        img, matRotation, (widthNew, heightNew), borderValue=(255, 255, 255)
    )
    # Map the two corner points through the same affine transform
    pt1 = list(pt1)
    pt3 = list(pt3)
    [[pt1[0]], [pt1[1]]] = np.dot(matRotation, np.array([[pt1[0]], [pt1[1]], [1]]))
    [[pt3[0]], [pt3[1]]] = np.dot(matRotation, np.array([[pt3[0]], [pt3[1]], [1]]))
    ydim, xdim = imgRotation.shape[:2]
    imgOut = imgRotation[
        max(1, int(pt1[1])) : min(ydim - 1, int(pt3[1])),
        max(1, int(pt1[0])) : min(xdim - 1, int(pt3[0])),
    ]
    return imgOut


def charRec(img, text_recs, adjust=False) -> dict:
    """
    Crop img into the text_recs rectangles that contain text and run the
    CRNN model for character recognition on each crop.
    Returns a dict keyed by box index:
        {
            0: [<bbox>, <text>],
            1: [<bbox>, <text>],
            ...
        }
    """
    results = {}
    xDim, yDim = img.shape[1], img.shape[0]
    # rec: a detected text rectangle, given as 8 values (4 corner points)
    for index, rec in enumerate(text_recs):
        xlength = int((rec[6] - rec[0]) * 0.1)
        ylength = int((rec[7] - rec[1]) * 0.2)
        if adjust:
            # Pad the rectangle slightly before cropping
            pt1 = (max(1, rec[0] - xlength), max(1, rec[1] - ylength))
            pt2 = (rec[2], rec[3])
            pt3 = (min(rec[6] + xlength, xDim - 2), min(yDim - 2, rec[7] + ylength))
            pt4 = (rec[4], rec[5])
        else:
            pt1 = (max(1, rec[0]), max(1, rec[1]))
            pt2 = (rec[2], rec[3])
            pt3 = (min(rec[6], xDim - 2), min(yDim - 2, rec[7]))
            pt4 = (rec[4], rec[5])
        # Deskew: rotate a slanted rectangle so a straight crop goes into CRNN
        degree = degrees(atan2(pt2[1] - pt1[1], pt2[0] - pt1[0]))
        partImg = dumpRotateImage(img, degree, pt1, pt2, pt3, pt4)
        # Skip crops that are empty or taller than they are wide
        if (
            partImg.shape[0] < 1  # zero height
            or partImg.shape[1] < 1  # zero width
            or partImg.shape[0] > partImg.shape[1]  # taller than wide
        ):
            continue
        # Recognize text in the cropped, deskewed box
        text = recognizer.recognize(partImg)
        if len(text) > 0:  # keep only non-empty recognitions
            results[index] = [rec] + [text]
    return results


def ocr(image: np.ndarray):
    """
    Detect and recognize text in four steps:
    1) use CTPN to detect large boxes of text
    2) sort the boxes top-to-bottom into the order CRNN expects
    3) use CRNN to recognize the text inside those boxes
    4) return the recognition results and the image with boxes drawn on it
    """
    # detect large boxes of text (CTPN)
    text_recs, img_framed, image = ctpn_model.get_det_boxes(image)
    # sort boxes top-to-bottom for CRNN
    text_recs = sort_box(text_recs)
    # recognize characters in those boxes (CRNN)
    result = charRec(image, text_recs)
    return result, img_framed
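

# Example usage (a minimal sketch): run the full pipeline on an image read
# with OpenCV. The path below is only a placeholder; point it at any image
# containing text.
if __name__ == "__main__":
    image = cv2.imread("test_images/demo.png")
    result, img_framed = ocr(image)
    for index, (bbox, text) in sorted(result.items()):
        print(index, text)
    dis(img_framed)  # show the image with the detected boxes drawn on it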