-
Notifications
You must be signed in to change notification settings - Fork 0
/
data_manager.py
96 lines (82 loc) · 3.38 KB
/
data_manager.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
import re
import os
import numpy as np
import cv2
from config import *
from scipy.misc import imread, imresize, imsave
from random import shuffle
import tensorflow as tf
class DataManager(object):
    """Feed (image, label-image) pairs from disk into a TF1 ``tf.data`` pipeline.

    Each entry of ``dataList`` is a pair of file names ``(image, label)``
    relative to ``param["data_dir"]``.  Both files are read as grayscale and
    resized to ``IMAGE_SIZE``; the label image is further downsampled by 8 and
    turned into a binary pixel mask plus a single 0/1 image-level label.
    The batched output tensors are exposed as ``self.next_batch``.
    """

    def __init__(self, dataList, param, shuffle=True):
        """Store the configuration and build the input pipeline.

        Args:
            dataList: sequence of ``(image_file, label_file)`` name pairs,
                relative to ``param["data_dir"]``.
            param: mapping with keys ``"data_dir"``, ``"epochs_num"`` and
                ``"batch_size"``.
            shuffle: whether the pipeline shuffles samples.
        """
        self.shuffle = shuffle
        self.data_list = dataList
        self.data_size = len(dataList)
        self.data_dir = param["data_dir"]
        self.epochs_num = param["epochs_num"]
        self.batch_size = param["batch_size"]
        # Number of complete batches in one pass over data_list.
        self.number_batch = int(self.data_size // self.batch_size)
        self.next_batch = self.get_next()

    def get_next(self):
        """Build the dataset and return the tensors of its next batch.

        Returns:
            Tuple of tensors ``(image, label_pixel, label, image_name)`` with
            dtypes ``(float32, int32, int32, string)``, each batched along a
            new leading axis of size up to ``self.batch_size``.
        """
        dataset = tf.data.Dataset.from_generator(
            self.generator, (tf.float32, tf.int32, tf.int32, tf.string))
        dataset = dataset.repeat(self.epochs_num)
        if self.shuffle:
            # NOTE(review): shuffling after repeat() lets samples of
            # neighbouring epochs mix in the buffer; move shuffle() before
            # repeat() if strict epoch boundaries are required.
            dataset = dataset.shuffle(self.batch_size * 3 + 200)
        dataset = dataset.batch(self.batch_size)
        iterator = dataset.make_one_shot_iterator()
        return iterator.get_next()

    def generator(self):
        """Yield one preprocessed sample per entry of ``self.data_list``.

        Yields:
            ``(image, label_pixel, label, image_name)``:

            - ``image``: the grayscale image with a trailing channel axis,
              shape ``(IMAGE_SIZE[0], IMAGE_SIZE[1], 1)``.
            - ``label_pixel``: the binary mask from ``label_preprocess``
              with a trailing channel axis.
            - ``label``: 0 or 1.
            - ``image_name``: the image file name exactly as stored in
              ``data_list``.
        """
        for image_name, label_name in self.data_list:
            image = self.read_data(os.path.join(self.data_dir, image_name))
            label_img = self.read_data(os.path.join(self.data_dir, label_name))
            label_pixel, label = self.label_preprocess(label_img)
            yield (image[:, :, np.newaxis],
                   label_pixel[:, :, np.newaxis],
                   label,
                   image_name)

    def read_data(self, data_name):
        """Read ``data_name`` as a grayscale image resized to ``IMAGE_SIZE``.

        Args:
            data_name: path of the image file.

        Returns:
            2-D array of shape ``(IMAGE_SIZE[0], IMAGE_SIZE[1])`` holding the
            grayscale values as loaded (no /255 scaling is applied).

        Raises:
            IOError: if OpenCV cannot read the file.
        """
        # Flag 0 == cv2.IMREAD_GRAYSCALE: load as a single-channel image.
        img = cv2.imread(data_name, 0)
        if img is None:
            # cv2.imread reports a missing/unreadable file by returning None
            # rather than raising; fail here with the offending path instead
            # of with an opaque assertion inside cv2.resize.
            raise IOError("cannot read image file: %s" % data_name)
        # cv2.resize takes dsize as (width, height).
        return cv2.resize(img, (IMAGE_SIZE[1], IMAGE_SIZE[0]))

    def label_preprocess(self, label):
        """Turn a label image into a pixel mask and an image-level label.

        Args:
            label: 2-D grayscale label image (as returned by ``read_data``).

        Returns:
            ``(label_pixel, label)``:

            - ``label_pixel``: the label resized to
              ``(int(IMAGE_SIZE[0]/8), int(IMAGE_SIZE[1]/8))`` and binarised
              by ``ImageBinarization`` (pixels > 1 become 1).
            - ``label``: an ``int`` that is 1 if any pixel of the resized
              label is non-zero, else 0.
        """
        small = cv2.resize(label, (int(IMAGE_SIZE[1] / 8), int(IMAGE_SIZE[0] / 8)))
        label_pixel = self.ImageBinarization(small)
        # NOTE(review): the image-level label uses "> 0" while the mask uses
        # the binarisation threshold (> 1), so a label image whose only
        # non-zero pixels equal 1 yields label 1 with an all-zero mask —
        # confirm this is intended.
        image_label = 1 if small.sum() > 0 else 0
        return label_pixel, image_label

    def ImageBinarization(self, img, threshold=1):
        """Return an integer array that is 1 where ``img > threshold`` and 0 elsewhere.

        Args:
            img: array-like of pixel values.
            threshold: values strictly greater than this map to 1.
        """
        return np.where(np.asarray(img) > threshold, 1, 0)

    def label2int(self, label):
        """Encode a character sequence as decoder input / target vectors.

        Token ids written here:

        - 0 (GO) at the start of the decoder input.
        - 1 right after the last character of the target (presumably EOS —
          confirm against ``VOCAB``).
        - 3 (PAD) in every unused position.

        Characters are mapped through ``VOCAB``.

        Args:
            label: sequence of characters (keys of ``VOCAB``); must be shorter
                than ``MAX_LEN_WORD`` so the GO / end markers still fit.

        Returns:
            ``(target_input, target_out)``, two float32 arrays of length
            ``MAX_LEN_WORD``:

            - ``target_input``: ``[0, c1, ..., cn, 3, ...]``
            - ``target_out``: ``[c1, ..., cn, 1, 3, ...]``

        Raises:
            ValueError: if ``len(label) >= MAX_LEN_WORD``.
        """
        if len(label) >= MAX_LEN_WORD:
            raise ValueError("label of length %d does not fit in MAX_LEN_WORD=%d"
                             % (len(label), MAX_LEN_WORD))
        target_input = np.full(MAX_LEN_WORD, 3, dtype=np.float32)  # all PAD
        target_out = np.full(MAX_LEN_WORD, 3, dtype=np.float32)  # all PAD
        target_input[0] = 0  # GO
        for j, char in enumerate(label):
            code = VOCAB[char]
            target_input[j + 1] = code  # input is shifted right by the GO token
            target_out[j] = code
        target_out[len(label)] = 1
        return target_input, target_out

    def int2label(self, decode_label):
        """Decode rows of token ids back into strings.

        Each row is scanned left to right:

        - scanning stops at the first id whose ``VOC_IND`` entry is
          ``'<EOS>'``;
        - id 3 (PAD) is skipped;
        - every other id is mapped through ``VOC_IND`` and appended.

        Args:
            decode_label: 2-D array-like of token ids, one row per sequence.

        Returns:
            List of decoded strings, one per row.
        """
        label = []
        for row in decode_label:
            chars = []
            for token_id in row:
                token = VOC_IND[token_id]
                if token == '<EOS>':
                    break
                if token_id == 3:  # PAD
                    continue
                chars.append(token)
            label.append(''.join(chars))
        return label

    def get_label(self, f):
        """Extract the label text from a file name.

        Returns the second ``_``-separated field of the base name's
        next-to-last ``.``-separated part, e.g. ``img_abc.jpg`` -> ``abc``.
        Only the base name is inspected, so underscores or dots in directory
        names do not affect the result.
        """
        return os.path.basename(f).split('.')[-2].split('_')[1]