-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy pathspeech_input.py
152 lines (109 loc) · 4.56 KB
/
speech_input.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
from abc import abstractmethod
import tensorflow as tf
import threading
import numpy as np
class BaseInputLoader:
    """Base class for preparing padded speech-feature batches and sparse
    label batches to feed into a TensorFlow graph.

    Subclasses provide the tensors via `get_inputs` and (optionally) a
    feed dictionary via `get_feed_dict`.
    """

    def __init__(self, input_size):
        # Number of feature values per time step (last input dimension).
        self.input_size = input_size

    def _get_inputs_feed_item(self, input_list):
        """Pad a list of variable-length inputs into one dense batch.

        Args:
            input_list: sequence of 2-D arrays shaped [time, input_size];
                time may differ per element.

        Returns:
            (input_tensor, sequence_lengths, max_time) where input_tensor is
            [batch, max_time, input_size] zero-padded past each sequence's
            length, sequence_lengths holds each original time dimension, and
            max_time is the longest sequence in the batch.
        """
        sequence_lengths = np.array([inp.shape[0] for inp in input_list])
        max_time = sequence_lengths.max()
        # Explicit float32 matches the tf.float32 placeholders this feeds and
        # halves the feed size versus NumPy's implicit float64 default.
        input_tensor = np.zeros((len(input_list), max_time, self.input_size),
                                dtype=np.float32)
        for idx, inp in enumerate(input_list):
            input_tensor[idx, :inp.shape[0], :] = inp
        return input_tensor, sequence_lengths, max_time

    @staticmethod
    def _get_labels_feed_item(label_list, max_time):
        """Convert a batch of label id sequences into a tf.SparseTensorValue.

        Args:
            label_list: sequence of label-id sequences, one per batch element.
            max_time: second dimension of the sparse tensor's dense shape.

        Returns:
            tf.SparseTensorValue with one entry per label id, indexed by
            (batch index, position within the label sequence).
        """
        # np.int was removed in NumPy 1.24; TensorFlow additionally requires
        # int64 for sparse indices and the dense shape.
        label_shape = np.array([len(label_list), max_time], dtype=np.int64)
        label_indices = []
        label_values = []
        for label_idx, label in enumerate(label_list):
            for id_idx, identifier in enumerate(label):
                label_indices.append([label_idx, id_idx])
                label_values.append(identifier)
        label_indices = np.array(label_indices, dtype=np.int64)
        # int32 values match the tf.int32 labels placeholder they are fed to.
        label_values = np.array(label_values, dtype=np.int32)
        return tf.SparseTensorValue(label_indices, label_values, label_shape)

    @abstractmethod
    def get_inputs(self):
        """Return (inputs, sequence_lengths, labels) tensors for the graph."""
        raise NotImplementedError()

    def get_feed_dict(self):
        """Return the feed dict for the next step, or None if not needed."""
        return None
class SingleInputLoader(BaseInputLoader):
    """
    Feeds one speech input at a time through the session feed dictionary.
    """

    def __init__(self, input_size):
        super().__init__(input_size)
        # Pending input, consumed by the next get_feed_dict() call.
        self.speech_input = None
        with tf.device("/cpu:0"):
            # Single-element batch: [1, max time, input size].
            self.inputs = tf.placeholder(
                tf.float32, [1, None, input_size], name='inputs')
            self.sequence_lengths = tf.placeholder(
                tf.int32, [1], name='sequence_lengths')

    def get_inputs(self):
        """Return the input and sequence-length tensors; labels are None."""
        return self.inputs, self.sequence_lengths, None

    def get_feed_dict(self):
        """
        Build the feed dictionary for the next model step, consuming the
        input previously stored via set_input().

        Raises:
            ValueError: if no input has been stored since the last call.
        """
        if self.speech_input is None:
            raise ValueError('Speech input must be provided using `set_input` first!')
        pending = self.speech_input
        self.speech_input = None  # consume: each input is fed exactly once
        tensor, lengths, _ = self._get_inputs_feed_item([pending])
        return {
            self.inputs: tensor,
            self.sequence_lengths: lengths,
        }

    def set_input(self, speech_input):
        """Store the next input, shaped [time, input_size], for feeding."""
        self.speech_input = speech_input
class InputBatchLoader(BaseInputLoader):
    """Fills a TensorFlow FIFO queue with padded input batches and serialized
    sparse labels from background threads, for consumption via get_inputs()."""

    def __init__(self, input_size, batch_size, data_generator_creator, max_steps=None):
        super().__init__(input_size)
        self.batch_size = batch_size
        # Factory (not a generator) so each _enqueue call creates its own
        # fresh data generator.
        self.data_generator_creator = data_generator_creator
        # Remaining number of batches to enqueue; None means unbounded.
        # NOTE(review): decremented without a lock in _enqueue — racy when
        # start_threads is called with n_threads > 1; confirm intended use
        # is a single enqueue thread.
        self.steps_left = max_steps
        with tf.device("/cpu:0"):
            # inputs is of dimension [batch size, max time, input size]
            self.inputs = tf.placeholder(tf.float32, [batch_size, None, input_size], name='inputs')
            self.sequence_lengths = tf.placeholder(tf.int32, [batch_size], name='sequence_lengths')
            self.labels = tf.sparse_placeholder(tf.int32, name='labels')
            # Labels travel through the queue as serialized strings because a
            # SparseTensor cannot be enqueued directly.
            self.queue = tf.FIFOQueue(dtypes=[tf.float32, tf.int32, tf.string], capacity=100)
            serialized_labels = tf.serialize_many_sparse(self.labels)
            self.enqueue_op = self.queue.enqueue([self.inputs, self.sequence_lengths, serialized_labels])

    def get_inputs(self):
        """Dequeue one batch: (inputs, sequence_lengths, sparse labels)."""
        with tf.device("/cpu:0"):
            inputs, sequence_lengths, labels = self.queue.dequeue()
            # Restore the SparseTensor that was serialized for the queue.
            labels = tf.deserialize_many_sparse(labels, dtype=tf.int32)
            return inputs, sequence_lengths, labels

    def _batch(self, iterable):
        """Group the sample stream into batch_size-tuples; zip silently drops
        any trailing partial batch."""
        args = [iter(iterable)] * self.batch_size
        return zip(*args)

    def _enqueue(self, sess, coord):
        """Thread body: pad, serialize and enqueue batches until the data or
        step budget is exhausted, or the coordinator requests a stop.

        Args:
            sess: session used to run the enqueue op.
            coord: tf.train.Coordinator polled between batches.
        """
        data_generator = self.data_generator_creator()
        for sample_batch in self._batch(data_generator):
            # Each sample is an (input, label) pair.
            input_list, label_list = zip(*sample_batch)
            input_tensor, sequence_lengths, max_time = self._get_inputs_feed_item(input_list)
            labels = self._get_labels_feed_item(label_list, max_time)
            sess.run(self.enqueue_op, feed_dict={
                self.inputs: input_tensor,
                self.sequence_lengths: sequence_lengths,
                self.labels: labels
            })
            if self.steps_left is not None:
                self.steps_left -= 1
                if self.steps_left == 0:
                    break
            if coord.should_stop():
                break
        # Close so pending dequeues fail fast once data is exhausted.
        # NOTE(review): with n_threads > 1 the first thread to finish closes
        # the queue while others may still be enqueueing — confirm intended.
        sess.run(self.queue.close())

    def start_threads(self, sess, coord, n_threads=1):
        # starts the background threads to fill the queue
        threads = []
        for n in range(n_threads):
            t = threading.Thread(target=self._enqueue, args=(sess, coord))
            t.daemon = True  # thread closes when parent quits
            t.start()
            # Registering lets coord.join() wait on this thread.
            coord.register_thread(t)
            threads.append(t)
        return threads