-
Notifications
You must be signed in to change notification settings - Fork 1.3k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Multi GPU (WIP) #169
base: master
Are you sure you want to change the base?
Multi GPU (WIP) #169
Changes from all commits
93e5e9e
9a329b6
1f00d7b
a210c58
19d11a0
fad20de
65c3f9e
ce0994b
60cbd0d
f6ea81e
44f92f7
2592516
46d44db
e835ca9
470d7b1
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -3,21 +3,36 @@ | |
from .ops import causal_conv, mu_law_encode | ||
|
||
|
||
def create_variable(name, shape): | ||
def _create_variable(name, shape): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I thought the underscore was for member functions (methods) of classes. Am I wrong? |
||
'''Create a convolution filter variable with the specified name and shape, | ||
and initialize it using Xavier initialition.''' | ||
initializer = tf.contrib.layers.xavier_initializer_conv2d() | ||
variable = tf.Variable(initializer(shape=shape), name=name) | ||
return variable | ||
|
||
|
||
def create_bias_variable(name, shape): | ||
def _create_bias_variable(name, shape): | ||
'''Create a bias variable with the specified name and shape and initialize | ||
it to zero.''' | ||
initializer = tf.constant_initializer(value=0.0, dtype=tf.float32) | ||
return tf.Variable(initializer(shape=shape), name) | ||
|
||
|
||
def _get_variable(name, shape): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. is the create_embedding_table() function should be changed to tf.get_variable() too? I'm not quite sure, if the embedding table is also trained as model parameters, should it be made shared among GPU? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @weixsong Yes. This PR is out-dated. :) |
||
'''Create a convolution filter variable with the specified name and shape, | ||
and initialize it using Xavier initialition.''' | ||
initializer = tf.contrib.layers.xavier_initializer_conv2d() | ||
variable = tf.get_variable(name, initializer=initializer(shape=shape)) | ||
return variable | ||
|
||
|
||
def _get_bias_variable(name, shape): | ||
'''Create a bias variable with the specified name and shape and initialize | ||
it to zero.''' | ||
initializer = tf.constant_initializer(value=0.0, dtype=tf.float32) | ||
return tf.get_variable(name, initializer=initializer(shape=shape)) | ||
|
||
|
||
class WaveNetModel(object): | ||
'''Implements the WaveNet network for generative audio. | ||
|
||
|
@@ -43,6 +58,7 @@ def __init__(self, | |
quantization_channels=2**8, | ||
use_biases=False, | ||
scalar_input=False, | ||
reuse_variables=False, | ||
initial_filter_width=32, | ||
histograms=False): | ||
'''Initializes the WaveNet model. | ||
|
@@ -83,6 +99,7 @@ def __init__(self, | |
self.scalar_input = scalar_input | ||
self.initial_filter_width = initial_filter_width | ||
self.histograms = histograms | ||
self.reuse_variables = reuse_variables | ||
|
||
self.variables = self._create_variables() | ||
|
||
|
@@ -93,6 +110,13 @@ def _create_variables(self): | |
|
||
var = dict() | ||
|
||
if self.reuse_variables: | ||
create_variable = _get_variable | ||
create_bias_variable = _get_bias_variable | ||
else: | ||
create_variable = _create_variable | ||
create_bias_variable = _create_bias_variable | ||
|
||
with tf.variable_scope('wavenet'): | ||
with tf.variable_scope('causal_layer'): | ||
layer = dict() | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
delete commented-out line