-
Notifications
You must be signed in to change notification settings - Fork 173
/
captcha_records.py
123 lines (103 loc) · 3.54 KB
/
captcha_records.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import argparse
import os.path
import sys
from PIL import Image
import numpy as np
import tensorflow as tf
from tensorflow.python.platform import gfile
import config
# Target size used when resizing every captcha image before serialization.
IMAGE_HEIGHT, IMAGE_WIDTH = config.IMAGE_HEIGHT, config.IMAGE_WIDTH

# Label encoding: the character collection looked up per label character and
# the two dimensions (CHARS_NUM x CLASSES_NUM) of the one-hot label matrix.
CHAR_SETS, CLASSES_NUM, CHARS_NUM = (
    config.CHAR_SETS, config.CLASSES_NUM, config.CHARS_NUM)

# Output directory and the record file names written for each data split.
RECORD_DIR, TRAIN_FILE, VALID_FILE = (
    config.RECORD_DIR, config.TRAIN_FILE, config.VALID_FILE)

# Parsed command-line flags; assigned in the ``__main__`` block below.
FLAGS = None
def _int64_feature(value):
    """Wrap a single integer in a ``tf.train.Feature``.

    Args:
        value: Integer stored as the only element of the feature's int64 list.

    Returns:
        A ``tf.train.Feature`` whose ``int64_list`` holds ``[value]``.
    """
    int64_list = tf.train.Int64List(value=[value])
    return tf.train.Feature(int64_list=int64_list)
def _bytes_feature(value):
    """Wrap a single byte string in a ``tf.train.Feature``.

    Args:
        value: Byte string stored as the only element of the feature's bytes
            list.

    Returns:
        A ``tf.train.Feature`` whose ``bytes_list`` holds ``[value]``.
    """
    bytes_list = tf.train.BytesList(value=[value])
    return tf.train.Feature(bytes_list=bytes_list)
def label_to_one_hot(label):
    """Encode a label as a one-hot matrix.

    Args:
        label: Sequence of characters. Each character is looked up with
            ``CHAR_SETS.index``, and the label may hold at most ``CHARS_NUM``
            characters.

    Returns:
        A ``np.uint8`` array of shape ``(CHARS_NUM, CLASSES_NUM)``. Row ``i``
        has a single 1 in column ``CHAR_SETS.index(label[i])``; rows beyond
        ``len(label)`` stay all-zero.

    Raises:
        ValueError: Propagated from ``CHAR_SETS.index`` when a character is
            not part of ``CHAR_SETS``.
        IndexError: If the label is longer than ``CHARS_NUM`` or a character
            index is not below ``CLASSES_NUM``.
    """
    one_hot_label = np.zeros([CHARS_NUM, CLASSES_NUM], dtype=np.uint8)
    for position, char in enumerate(label):
        # Address one (row, column) cell explicitly. Indexing with a list of
        # two lists is read by current NumPy as a single 2-D index array,
        # which would set entire rows to 1 instead of individual cells.
        one_hot_label[position, CHAR_SETS.index(char)] = 1
    return one_hot_label
def conver_to_tfrecords(data_set, name):
    """Serialize (image, label) pairs into a TFRecord file under RECORD_DIR.

    Each pair becomes one ``tf.train.Example`` with the features ``height``,
    ``width``, ``label_raw`` (bytes of the one-hot label matrix) and
    ``image_raw`` (bytes of the image array).

    Args:
        data_set: Iterable of items whose element 0 is a NumPy image array
            with at least two dimensions and whose element 1 is a label
            accepted by ``label_to_one_hot``. One-shot iterators (such as
            ``zip`` on Python 3) are accepted.
        name: File name of the record file, joined onto ``RECORD_DIR``.

    Raises:
        TypeError: If ``data_set`` is None.
    """
    # Reject a missing data set before touching the filesystem so that no
    # empty record file is left behind.
    if data_set is None:
        raise TypeError(
            'data_set must be an iterable of (image, label) pairs, got None')

    if not os.path.exists(RECORD_DIR):
        os.makedirs(RECORD_DIR)
    filename = os.path.join(RECORD_DIR, name)
    print('>> Writing', filename)

    writer = tf.python_io.TFRecordWriter(filename)
    try:
        # Iterate directly instead of len()/indexing so iterators work too.
        for example_data in data_set:
            image = example_data[0]
            label = example_data[1]
            height = image.shape[0]
            width = image.shape[1]
            # tobytes() is the supported spelling of the deprecated tostring().
            image_raw = image.tobytes()
            label_raw = label_to_one_hot(label).tobytes()
            example = tf.train.Example(features=tf.train.Features(feature={
                'height': _int64_feature(height),
                'width': _int64_feature(width),
                'label_raw': _bytes_feature(label_raw),
                'image_raw': _bytes_feature(image_raw)}))
            writer.write(example.SerializeToString())
    finally:
        # Always release the file handle, even if an example fails to encode.
        writer.close()
    print('>> Writing Done!')
def create_data_list(image_dir):
    """Load every captcha image found directly inside a directory.

    Files matching ``*.jpg``, ``*.jpeg`` or ``*.png`` (lower- or upper-case
    extension) are opened, converted to grayscale, resized to
    ``IMAGE_WIDTH`` x ``IMAGE_HEIGHT`` and stored as ``int16`` arrays. The
    label of an image is the part of its base name before the first
    underscore.

    Args:
        image_dir: Directory that is searched for image files.

    Returns:
        A list of ``(image, label)`` tuples, or None when the directory does
        not exist or contains no matching file (a message is printed in both
        cases).
    """
    if not gfile.Exists(image_dir):
        print("Image director '" + image_dir + "' not found.")
        return None

    extensions = ['jpg', 'JPG', 'jpeg', 'JPEG', 'png', 'PNG']
    print("Looking for images in '" + image_dir + "'")
    file_list = []
    seen = set()
    for extension in extensions:
        file_glob = os.path.join(image_dir, '*.' + extension)
        for path in gfile.Glob(file_glob):
            # Patterns that differ only by case can match the same file on a
            # case-insensitive filesystem; keep each path once.
            if path not in seen:
                seen.add(path)
                file_list.append(path)
    if not file_list:
        print("No files found in '" + image_dir + "'")
        return None

    images = []
    labels = []
    for file_name in file_list:
        # The context manager closes the file even if decoding fails.
        with Image.open(file_name) as image:
            image_gray = image.convert('L')
            image_resize = image_gray.resize(size=(IMAGE_WIDTH, IMAGE_HEIGHT))
            input_img = np.array(image_resize, dtype='int16')
        label_name = os.path.basename(file_name).split('_')[0]
        images.append(input_img)
        labels.append(label_name)
    # Return a real list: on Python 3 a bare zip() is a one-shot iterator
    # that supports neither len() nor indexing.
    return list(zip(images, labels))
def main(_):
    """Convert the training and validation image directories to TFRecords.

    Reads ``FLAGS.train_dir`` and ``FLAGS.valid_dir``, loads each with
    ``create_data_list`` and writes the results to ``TRAIN_FILE`` and
    ``VALID_FILE`` via ``conver_to_tfrecords``.

    Args:
        _: Unused argument list supplied by the runner.

    Returns:
        0 when both record files were written, 1 when at least one directory
        was skipped because it yielded no data.
    """
    status = 0
    jobs = ((FLAGS.train_dir, TRAIN_FILE), (FLAGS.valid_dir, VALID_FILE))
    for image_dir, record_name in jobs:
        data = create_data_list(image_dir)
        if data is None:
            # create_data_list has already printed the reason; skip this
            # split instead of crashing inside the converter.
            status = 1
            continue
        conver_to_tfrecords(data, record_name)
    # NOTE(review): tf.app.run is expected to forward this value to sys.exit,
    # making it the process exit status; confirm for the TF version in use.
    return status
if __name__ == '__main__':
    # Script entry point: declare the two input-directory flags, store the
    # parsed values in the module-level FLAGS, and hand control to main().
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--train_dir', type=str, default='./data/train_data',
        help='Directory training to get captcha data files and write the converted result.')
    parser.add_argument(
        '--valid_dir', type=str, default='./data/valid_data',
        help='Directory validation to get captcha data files and write the converted result.')
    # parse_known_args keeps unrecognized arguments so they can be forwarded
    # to tf.app.run together with the program name.
    FLAGS, unparsed = parser.parse_known_args()
    forwarded_argv = [sys.argv[0]] + unparsed
    tf.app.run(main=main, argv=forwarded_argv)