If filter width = 3, how to do fast inference? #19

Open
weixsong opened this issue Dec 6, 2017 · 2 comments

weixsong commented Dec 6, 2017

In the new paper, Google uses filter width 3 to increase the receptive field.

How can we do fast inference with filter width 3?
My idea is to use two queues. Since the dilation still doubles at each layer, a width-3 layer with dilation d needs the inputs from d and 2d steps back: the first queue stores the more recent half of these past values (x[t-d]), and the second queue stores the older half (x[t-2d]).
The output of the first queue is then enqueued into the second queue.
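The two-queue idea can be checked in plain Python before wiring it into TensorFlow. This is a minimal single-layer sketch under simplifying assumptions: scalar samples, and a width-3 dilated causal convolution y[t] = w[0]*x[t-2d] + w[1]*x[t-d] + w[2]*x[t] with zero left-padding. `Width3Layer` and `naive_dilated_conv3` are hypothetical names, and two `collections.deque`s of length d stand in for the two `tf.FIFOQueue`s. Each step pops x[t-d] from the first queue and x[t-2d] from the second, then pushes x[t] into the first and the popped x[t-d] into the second.

```python
from collections import deque


def naive_dilated_conv3(x, w, d):
    """Reference width-3 dilated causal conv over a whole sequence (zero left-padding):
    y[t] = w[0]*x[t-2d] + w[1]*x[t-d] + w[2]*x[t]."""
    def at(i):
        return x[i] if i >= 0 else 0.0
    return [w[0] * at(t - 2 * d) + w[1] * at(t - d) + w[2] * x[t] for t in range(len(x))]


class Width3Layer:
    """Hypothetical: the same layer evaluated one sample at a time with two chained queues."""

    def __init__(self, w, d):
        self.w = w
        # First queue: the last d inputs, so its head is x[t-d].
        self.q1 = deque([0.0] * d)
        # Second queue: the last d values that left q1, so its head is x[t-2d].
        self.q2 = deque([0.0] * d)

    def step(self, x_t):
        x_d = self.q1.popleft()    # like current_state = q.dequeue()
        x_2d = self.q2.popleft()   # like pre_state = q2.dequeue()
        self.q1.append(x_t)        # like push = q.enqueue([current_layer])
        self.q2.append(x_d)        # like push2 = q2.enqueue([current_state])
        return self.w[0] * x_2d + self.w[1] * x_d + self.w[2] * x_t


x = [1.0, -2.0, 0.5, 3.0, 0.0, -1.5, 2.5, 1.0, 4.0, -0.5]
w = (0.25, -0.5, 2.0)
layer = Width3Layer(w, d=2)
fast = [layer.step(v) for v in x]
print(fast == naive_dilated_conv3(x, w, d=2))  # True
```

Because every value popped from the first queue is pushed into the second, the second queue always lags the first by exactly d steps, which is why both queues can have capacity d.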

For example, in the generator:

        # q and init are created just above, as in the existing width-2
        # generator (not shown); for the initial causal layer q has size 1.
        current_state = q.dequeue()
        push = q.enqueue([current_layer])
        init_ops.append(init)
        push_ops.append(push)

        pre_state = None
        if self.filter_width == 3:
            # q2 receives whatever leaves q, so it lags q by one more step.
            q2 = tf.FIFOQueue(
                1,
                dtypes=tf.float32,
                shapes=(self.batch_size, self.quantization_channels))

            init2 = q2.enqueue_many(
                tf.zeros((1, self.batch_size, self.quantization_channels)))

            pre_state = q2.dequeue()
            push2 = q2.enqueue([current_state])

            # init_ops2 / push_ops2 must be run together with init_ops / push_ops.
            init_ops2.append(init2)
            push_ops2.append(push2)

        if self.filter_width == 2:
            current_layer = self._generator_causal_layer(
                current_layer, current_state)
        if self.filter_width == 3:
            current_layer = self._generator_causal_layer(
                current_layer, current_state, pre_state)
 


...
        with tf.name_scope('dilated_stack'):
            for layer_index, dilation in enumerate(self.dilations):
                with tf.name_scope('layer{}'.format(layer_index)):

                    # q holds the last `dilation` inputs of this layer,
                    # so its head is x[t - dilation].
                    q = tf.FIFOQueue(
                        dilation,
                        dtypes=tf.float32,
                        shapes=(self.batch_size, self.residual_channels))
                    init = q.enqueue_many(
                        tf.zeros((dilation, self.batch_size,
                                  self.residual_channels)))

                    current_state = q.dequeue()
                    push = q.enqueue([current_layer])
                    init_ops.append(init)
                    push_ops.append(push)

                    pre_state = None
                    if self.filter_width == 3:
                        # q2 receives the values that leave q, so its head
                        # is x[t - 2 * dilation].
                        q2 = tf.FIFOQueue(
                            dilation,
                            dtypes=tf.float32,
                            shapes=(self.batch_size, self.residual_channels))

                        init2 = q2.enqueue_many(
                            tf.zeros((dilation, self.batch_size,
                                      self.residual_channels)))

                        pre_state = q2.dequeue()
                        push2 = q2.enqueue([current_state])

                        init_ops2.append(init2)
                        push_ops2.append(push2)

                    # pre_state stays None when filter_width == 2.
                    output, current_layer = self._generator_dilation_layer(
                        current_layer, current_state, layer_index, dilation,
                        global_condition_batch, local_condition, pre_state)
                    outputs.append(output)

Does that make sense?
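Stacking the same construction over a dilation schedule gives a sketch of the whole `dilated_stack` loop. This is an illustration under simplifying assumptions (scalar channels, a plain `tanh` in place of WaveNet's gated residual units, no skip connections or conditioning); `FastStack`, `full_stack`, `conv3_tanh` and `receptive_field` are hypothetical names. Each layer owns a pair of chained queues sized by its dilation, and the sample-by-sample outputs should match a full recomputation over the sequence.

```python
import math
from collections import deque


def receptive_field(width, dilations):
    """Receptive field of stacked dilated causal convs: 1 + (width - 1) * sum(dilations)."""
    return 1 + (width - 1) * sum(dilations)


def conv3_tanh(h, w, d):
    """One width-3 dilated causal layer over a whole sequence, zero-padded on the left."""
    def at(i):
        return h[i] if i >= 0 else 0.0
    return [math.tanh(w[0] * at(t - 2 * d) + w[1] * at(t - d) + w[2] * h[t])
            for t in range(len(h))]


def full_stack(x, weights, dilations):
    """Reference: recompute every layer over the full sequence."""
    h = list(x)
    for w, d in zip(weights, dilations):
        h = conv3_tanh(h, w, d)
    return h


class FastStack:
    """Hypothetical sample-by-sample generator: one pair of chained queues per layer."""

    def __init__(self, weights, dilations):
        self.weights = weights
        # Both queues of a layer have capacity equal to its dilation and start
        # filled with zeros (the counterpart of init / init2).
        self.queues = [(deque([0.0] * d), deque([0.0] * d)) for d in dilations]

    def step(self, x_t):
        h = x_t
        for w, (q1, q2) in zip(self.weights, self.queues):
            h_d = q1.popleft()    # h[t - d]   (current_state)
            h_2d = q2.popleft()   # h[t - 2d]  (pre_state)
            q1.append(h)          # push
            q2.append(h_d)        # push2: what leaves q1 enters q2
            h = math.tanh(w[0] * h_2d + w[1] * h_d + w[2] * h)
        return h


dilations = [1, 2, 4, 8]
weights = [(0.3, -0.6, 0.9), (0.5, 0.2, -0.7), (-0.4, 0.8, 0.6), (0.1, -0.3, 1.1)]
x = [math.sin(0.7 * t) for t in range(40)]

gen = FastStack(weights, dilations)
fast = [gen.step(v) for v in x]
slow = full_stack(x, weights, dilations)
print(max(abs(a - b) for a, b in zip(fast, slow)))
print(receptive_field(2, dilations), receptive_field(3, dilations))  # 16 31
```

With dilations 1, 2, 4, 8 the receptive field is 1 + (k - 1) * (1 + 2 + 4 + 8), i.e. 16 samples for width 2 and 31 for width 3. The price is one extra queue per layer, which doubles the generator's cached state.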

@KingStorm

I think that's the idea.


twidddj commented Mar 30, 2018

@weixsong Hi, we have considered this issue. You can find the method in our repository.
