forked from SeraFxy/ziti
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdemo.py
83 lines (62 loc) · 2.29 KB
/
demo.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import os
import numpy as np
from scipy.misc import imresize
import matplotlib.pyplot as plt
import tensorflow as tf
from dataset.generator import read_text, draw_font
from nnets.vgg import vgg
def generator_images(texts, fonts):
    """Lazily yield ``(image, text)`` for every text rendered in every font.

    Texts drive the outer loop and fonts the inner loop, so when ``fonts`` is
    a sequence the n-th yielded item was drawn with ``fonts[n % len(fonts)]``.

    Args:
        texts: iterable of strings to render.
        fonts: iterable of font paths; it is re-iterated once per text.

    Yields:
        ``(image, text)`` where ``image`` is the first element of the pair
        returned by ``draw_font(text, font, mode='test')``.
    """
    combos = ((text, font) for text in texts for font in fonts)
    for text, font in combos:
        # draw_font returns a 2-tuple; only the rendered picture is needed here.
        picture, _unused = draw_font(text, font, mode='test')
        yield picture, text
def show_errors(error_infos, fonts):
    """Plot each misclassified text rendered in every candidate font.

    Builds a grid with one row per entry of ``error_infos`` and one column per
    font. In each row, the column whose index has the highest (rounded) score
    is titled with the rounded score list, so the wrong prediction stands out.

    Args:
        error_infos: sequence of ``(text, pred)`` pairs, where ``pred`` is the
            list of per-class scores the model produced for ``text``.
        fonts: sequence of font paths; their order defines the grid columns.
    """
    n_rows = len(error_infos)
    n_cols = len(fonts)
    # Print the font order so the plotted columns can be matched to files.
    print(fonts)
    for row, (text, scores) in enumerate(error_infos):
        # Round the scores so the subplot title stays readable.
        scores = [round(score, 5) for score in scores]
        predicted = scores.index(max(scores))
        for col, font in enumerate(fonts):
            ax = plt.subplot(n_rows, n_cols, row * n_cols + col + 1)
            ax.axis('off')
            picture, _ = draw_font(text, font, mode='test')
            if col == predicted:
                plt.title(str(scores))
            plt.imshow(picture)
    plt.show()
def run():
    """Evaluate the restored VGG font classifier and plot its mistakes.

    Every text returned by ``read_text('test.txt')`` is rendered in every font
    found under ``dataset/fonts``. Each image is pushed through the network
    restored from ``models/vgg.ckpt``. A sample counts as an error when the
    arg-max class differs from the index of the font used to draw it. The
    function then prints the totals and accuracy, and passes the
    misclassified samples to ``show_errors``.
    """
    file_name = u'test.txt'
    # Alternative, larger test set:
    # file_name = u'dataset/中国汉字大全.txt'
    texts = read_text(file_name)

    fonts_dir = os.path.join('dataset', 'fonts')
    # NOTE(review): os.listdir order is filesystem-dependent; the class index
    # of each font must match the order used at training time -- confirm.
    fonts = [os.path.join(os.getcwd(), fonts_dir, path) for path in os.listdir(fonts_dir)]
    num_fonts = len(fonts)
    images_gen = generator_images(texts, fonts)

    # Resize an arbitrary HxWx3 image to 128x128, standardize it and add a
    # batch dimension of 1 before feeding it to the network.
    inputs = tf.placeholder(tf.float32, shape=[None, None, 3])
    example = tf.cast(tf.image.resize_images(inputs, [128, 128]), tf.uint8)
    example = tf.image.per_image_standardization(example)
    example = tf.expand_dims(example, 0)
    # NOTE(review): 5 is presumably the class count the checkpoint was trained
    # with (must equal the number of fonts); 1.0 is presumably a dropout
    # keep-probability -- confirm against nnets.vgg.
    outputs = vgg(example, 5, 1.0)

    total = 0
    error = 0
    error_texts = []
    # The context manager guarantees the session is closed, even on errors.
    with tf.Session() as sess:
        restorer = tf.train.Saver()
        restorer.restore(sess, 'models/vgg.ckpt')
        for index, (image, text) in enumerate(images_gen):
            pred = sess.run(outputs, feed_dict={inputs: np.asarray(image)})
            pred = np.squeeze(pred)
            # argmax always yields a single class; np.where(pred == max) returns
            # several indices on ties and makes the comparison ambiguous.
            label = int(np.argmax(pred))
            # Fonts form the inner loop of generator_images, so sample `index`
            # was drawn with font (and true class) index % num_fonts.
            if index % num_fonts != label:
                error_texts.append((text, pred.tolist()))
                error += 1
            total = index + 1

    # Divide by the number of samples, not by the last 0-based index (which
    # was off by one and raised ZeroDivisionError for a single sample).
    acc = 1 - float(error) / total if total else float('nan')
    print('test num: {}, error num: {}, acc: {}'.format(total, error, acc))
    show_errors(error_texts, fonts)
if __name__ == '__main__':
    # Script entry point: evaluate the model only when executed directly,
    # not when this module is imported.
    run()