-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdataset.py
230 lines (199 loc) · 6.99 KB
/
dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
#!/usr/bin/env python
##
## dataset.py
##
## Made by xsyann
## Contact <[email protected]>
##
## Started on Fri Mar 28 15:09:44 2014 xsyann
## Last update Thu Apr 3 21:02:45 2014 xsyann
##
"""
Dataset pre-processing for OCR using OpenCV.
Authors:
Nicolas PELICAN
Yann KOETH
"""
import os
import string
import random
import numpy as np
import cv2
class Dataset(object):
    """Dataset of image files.

    Generate samples and responses arrays from images found in the
    label -> folder mapping supplied at construction time.
    """

    def __init__(self, folders):
        """Init the dataset from a folder dict of form
        {'label1': 'path/item1', 'label2': 'path/label2', ...}
        """
        self.testSamples = None
        self.testResponses = None
        self.trainSamples = None
        self.trainResponses = None
        # Cap on how many images are loaded per class.
        self.maxPerClass = 170
        self.__folders = folders
        self.__classifications = self.__loadClassifications()

    def preprocess(self, trainRatio):
        """Load, pre-process and split the images into train/test arrays.

        trainRatio is the fraction (0..1) of each class that goes into
        the training set.
        """
        self.__stackArrays(self.__getItems(trainRatio))

    def getResponse(self, index):
        """Return the label associated with a response index."""
        return self.__classifications[index]

    @property
    def classificationCount(self):
        """Number of distinct labels in the dataset."""
        return len(self.__classifications)

    @property
    def trainSampleCount(self):
        """Number of training samples (0 before preprocess() is called)."""
        # Guard against None: the arrays only exist after preprocess().
        if self.trainSamples is None or self.trainSamples.size == 0:
            return 0
        return self.trainSamples.shape[0]

    @property
    def testSampleCount(self):
        """Number of test samples (0 before preprocess() is called)."""
        if self.testSamples is None or self.testSamples.size == 0:
            return 0
        return self.testSamples.shape[0]

    def __loadClassifications(self):
        """Collect the unique labels, preserving dict iteration order."""
        classifications = []
        # .items() instead of .iteritems() keeps this working on Python 3.
        for label, folder in self.__folders.items():
            if label not in classifications:
                classifications.append(label)
        return classifications

    def __stackArraysAux(self, items):
        """Create samples and responses arrays from a list of items."""
        if not items:
            return (np.array([]), np.array([]))
        samples = []
        responses = []
        for item in items:
            responses.append(self.__classifications.index(item.classification))
            samples.append(item.sample)
        return (np.vstack(samples), np.array(responses))

    def __stackArrays(self, items):
        """Fill train/test samples and responses from (train, test) items."""
        trainItems, testItems = items
        self.trainSamples, self.trainResponses = self.__stackArraysAux(trainItems)
        self.testSamples, self.testResponses = self.__stackArraysAux(testItems)

    def __getItems(self, trainRatio):
        """Create dataset items.

        Returns a (trainItems, testItems) tuple; at most maxPerClass
        images are kept for each label, shuffled before the split.
        """
        trainItems = []
        testItems = []
        for label, folder in self.__folders.items():
            images = self.__getImages(folder)
            random.shuffle(images)
            currentClassItems = []
            for image in images[:self.maxPerClass]:
                item = DatasetItem()
                item.loadFromFile(image)
                item.classification = label
                currentClassItems.append(item)
            trainCount = int(np.ceil(min(len(images), self.maxPerClass) * trainRatio))
            trainItems.extend(currentClassItems[:trainCount])
            testItems.extend(currentClassItems[trainCount:])
        return (trainItems, testItems)

    def __getImages(self, folder):
        """Return paths of all .bmp/.png images directly inside folder."""
        imgExt = (".bmp", ".png")
        images = []
        if os.path.isdir(folder):
            for entry in os.listdir(folder):
                _, ext = os.path.splitext(entry)
                if ext.lower() in imgExt:
                    images.append(os.path.join(folder, entry))
        return images
class DatasetItem(object):
    """An item in the data set.

    Handle pre-processing of that item: grayscale, blur, adaptive
    threshold, crop to the content bounding box, pad to a square and
    resize to RESIZE x RESIZE.
    """

    # Side length (pixels) of the final square sample.
    RESIZE = 16

    def __init__(self):
        self.input = None           # original BGR image
        self.preprocessed = None    # final RESIZE x RESIZE binary image
        self.classification = None  # label assigned by the dataset

    def loadFromFile(self, filename):
        """Load and pre-process the image stored at filename.

        Raises OSError if the file does not exist.
        """
        if not os.path.isfile(filename):
            raise OSError(2, 'File not found', filename)
        self.__load(filename)
        self.__preprocess()

    def loadFromImage(self, img):
        """Pre-process an already-loaded BGR image."""
        self.input = img
        self.__preprocess()

    @property
    def sample(self):
        """Return the pre-processed image as a flat float32 vector."""
        sample = np.array(self.preprocessed)
        return sample.ravel().astype(np.float32)

    def __load(self, filename):
        # cv2.IMREAD_COLOR replaces cv2.CV_LOAD_IMAGE_COLOR, which was
        # removed in OpenCV 3.
        self.input = cv2.imread(filename, cv2.IMREAD_COLOR)

    def __mergeContours(self, contours):
        """Merge all bounding boxes.

        Returns x, y, w, h of the union of the contours' boxes.
        """
        x, y, x1, y1 = [], [], [], []
        for cnt in contours:
            pX, pY, pW, pH = cv2.boundingRect(cnt)
            x.append(pX)
            y.append(pY)
            x1.append(pX + pW)
            y1.append(pY + pH)
        bbX, bbY = min(x), min(y)
        bbW, bbH = max(x1) - bbX, max(y1) - bbY
        return bbX, bbY, bbW, bbH

    def __cropToFit(self, image):
        """Crop image to fit the bounding box of its contours.

        Also draws the detected box on self.input (visualization aid).
        """
        clone = image.copy()
        # findContours returns (image, contours, hierarchy) on OpenCV 3
        # and (contours, hierarchy) on OpenCV 2/4; the contour list is
        # always the second-to-last element.
        results = cv2.findContours(clone, cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE)
        contours = results[-2]
        if not contours:
            return image
        x, y, w, h = self.__mergeContours(contours)
        cv2.rectangle(self.input, (x, y), (x + w, y + h), (0, 0, 255), 1)
        return image[y:y+h, x:x+w]

    def __ratioResize(self, image):
        """Pad image with black to get an aspect ratio of 1:1 (square)."""
        h, w = image.shape
        ratioSize = max(h, w)
        blank = np.zeros((ratioSize, ratioSize), np.uint8)
        # Integer division: float slice indices are rejected by numpy.
        x = (ratioSize - w) // 2
        y = (ratioSize - h) // 2
        blank[y:y+h, x:x+w] = image
        return blank

    def __preprocess(self):
        """Pre-process self.input into self.preprocessed:
        - Convert To Grayscale
        - Gaussian Blur (remove noise)
        - Threshold (black and white image)
        - Crop to fit bounding box
        - Resize
        """
        gray = cv2.cvtColor(self.input, cv2.COLOR_BGR2GRAY)
        blur = cv2.GaussianBlur(src=gray, ksize=(5, 5), sigmaX=0)
        thresh = cv2.adaptiveThreshold(src=blur, maxValue=255,
                                       adaptiveMethod=cv2.ADAPTIVE_THRESH_GAUSSIAN_C,
                                       thresholdType=cv2.THRESH_BINARY_INV,
                                       blockSize=11, C=2)
        cropped = self.__cropToFit(thresh)
        squared = self.__ratioResize(cropped)
        self.preprocessed = cv2.resize(squared, (self.RESIZE, self.RESIZE))
if __name__ == "__main__":
    import argparse
    import sys

    parser = argparse.ArgumentParser(description="Show the pre-processing step of OCR.")
    parser.add_argument("filename", help="File to pre-process")
    args = parser.parse_args()
    item = DatasetItem()
    try:
        item.loadFromFile(args.filename)
    except (OSError, cv2.error) as err:
        # print() function form: the Python 2 print statement is a
        # syntax error under Python 3.
        print(err)
        sys.exit(1)
    print(__doc__)
    cv2.imshow("Input", item.input)
    cv2.imshow("Pre-processed", item.preprocessed)
    cv2.moveWindow("Input", 200, 0)
    cv2.waitKey(0)
    cv2.destroyAllWindows()