why is ' self.lstm_cell(X, state) ' ? #20

Open
fanyuzeng opened this issue May 7, 2018 · 0 comments

Comments


fanyuzeng commented May 7, 2018

I successfully ran the program. However, I found there is something that seems abnormal.

class RecurrentController(BaseController):
    def network_vars(self):
        self.lstm_cell = tf.nn.rnn_cell.BasicLSTMCell(256)
        self.state = self.lstm_cell.zero_state(self.batch_size, tf.float32)

    def network_op(self, X, state):
        X = tf.convert_to_tensor(X)
        return self.lstm_cell(X, state)

    def get_state(self):
        return self.state

    def update_state(self, new_state):
        return tf.no_op()
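For context on what that direct call computes: `tf.nn.rnn_cell.BasicLSTMCell(256)` builds a single LSTM layer with 256 hidden units, and invoking it as `cell(X, state)` applies one timestep. The following is a hedged NumPy sketch of that single step (standard LSTM equations; `lstm_step`, the weight names, and all shapes are illustrative stand-ins, not the repo's code):

```python
import numpy as np

# Hypothetical stand-in for one call to an LSTM cell: a single timestep
# applied to a batch, returning (output, new_state) like cell(X, state).
def lstm_step(x, h, c, W, b):
    # x: (batch, input_dim); h, c: (batch, num_units)
    z = np.concatenate([x, h], axis=1) @ W + b   # (batch, 4 * num_units)
    i, f, g, o = np.split(z, 4, axis=1)          # input/forget/candidate/output gates
    sigmoid = lambda a: 1.0 / (1.0 + np.exp(-a))
    c_new = sigmoid(f) * c + sigmoid(i) * np.tanh(g)
    h_new = sigmoid(o) * np.tanh(c_new)
    return h_new, (h_new, c_new)                 # output, (hidden, cell) state

batch, input_dim, num_units = 4, 10, 256
rng = np.random.default_rng(0)
W = rng.normal(scale=0.1, size=(input_dim + num_units, 4 * num_units))
b = np.zeros(4 * num_units)
x = rng.normal(size=(batch, input_dim))
h0 = np.zeros((batch, num_units))
c0 = np.zeros((batch, num_units))

out, (h1, c1) = lstm_step(x, h0, c0, W, b)
print(out.shape)  # (4, 256)
```

Under this reading, one call processes one input slice and also hands back the updated state tuple, which the controller can feed into the next call.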

In the code above, tf.nn.rnn_cell.BasicLSTMCell(256) creates an LSTM cell with 256 hidden units, but the code directly uses return self.lstm_cell(X, state) without tf.static_rnn or tf.nn.dynamic_rnn.
tf.static_rnn and tf.nn.dynamic_rnn can output the state, but self.lstm_cell(X, state) can't.
I wonder whether this is wrong and whether tf.static_rnn or tf.nn.dynamic_rnn needs to be added, as follows:

def network_op(self, X, state):
    X = tf.convert_to_tensor(X)
    return tf.nn.dynamic_rnn(self.lstm_cell, X, initial_state=state, time_major=False)
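The distinction the question turns on can be sketched in plain NumPy: a cell call is one timestep, while `tf.nn.dynamic_rnn` is (conceptually) a loop that threads the state through successive cell calls over the time axis. All names and shapes below are illustrative stand-ins (a toy tanh cell rather than an LSTM), not TensorFlow internals:

```python
import numpy as np

# Toy tanh RNN cell standing in for self.lstm_cell: one timestep in,
# (output, new_state) out -- the call signature discussed in the issue.
def rnn_cell(x, state, W_x, W_h):
    new_state = np.tanh(x @ W_x + state @ W_h)
    return new_state, new_state

# What a dynamic_rnn-style wrapper adds: a loop over the time axis that
# threads the state through repeated cell calls and stacks the outputs.
def unrolled_rnn(cell, X, initial_state, W_x, W_h):
    # X: (batch, time, input_dim), i.e. batch-major like time_major=False
    outputs, state = [], initial_state
    for t in range(X.shape[1]):
        out, state = cell(X[:, t, :], state, W_x, W_h)
        outputs.append(out)
    return np.stack(outputs, axis=1), state      # (batch, time, units), final state

batch, time, input_dim, units = 2, 5, 3, 8
rng = np.random.default_rng(1)
W_x = rng.normal(scale=0.1, size=(input_dim, units))
W_h = rng.normal(scale=0.1, size=(units, units))
X = rng.normal(size=(batch, time, input_dim))
state0 = np.zeros((batch, units))

outputs, final_state = unrolled_rnn(rnn_cell, X, state0, W_x, W_h)
print(outputs.shape)  # (2, 5, 8)
```

In this sketch both the single cell call and the unrolled loop return a state; the wrapper's job is handling a whole (batch, time, features) sequence in one call rather than one timestep.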
