Hello, I was trying to understand the LSTMCell of Flax. The documentation for the __call__ function says:
carry – the hidden state of the LSTM cell, initialized using LSTMCell.initialize_carry.
I thought it was weird that the cell state wasn't returned along with the hidden state. But in the source code, initialize_carry seems to return a tuple containing both the cell state and the hidden state:
return (c, h)
Additionally, the __call__ function also seems to return both the cell state and the hidden state in the carry:
return (new_c, new_h), new_h
Did I misunderstand something? If not, should the documentation be updated to clarify that the carry includes both the cell state and the hidden state?
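To make the question concrete, here is a toy sketch of the carry convention as I understand it from the source. This is not Flax's actual code: it uses plain Python scalars, the weights w and u are made up, all four gates share them for brevity, and the names sigmoid and lstm_step are hypothetical. Only the (c, h) carry layout and the shape of the return value are meant to mirror the two return statements quoted above.

```python
import math


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


def initialize_carry():
    # Hypothetical stand-in for the real initializer: both states start
    # at zero and are returned together as a (c, h) tuple.
    c, h = 0.0, 0.0
    return (c, h)


def lstm_step(carry, x, w=0.5, u=0.5):
    # Toy scalar LSTM step. The weights w and u are made up, and every
    # gate reuses the same pre-activation z to keep the sketch short.
    c, h = carry
    z = w * x + u * h
    i = sigmoid(z)    # input gate
    f = sigmoid(z)    # forget gate
    g = math.tanh(z)  # candidate cell value
    o = sigmoid(z)    # output gate
    new_c = f * c + i * g
    new_h = o * math.tanh(new_c)
    # The carry holds both states; the per-step output is the hidden state.
    return (new_c, new_h), new_h


carry = initialize_carry()
outputs = []
for x in [1.0, -0.5, 2.0]:
    carry, y = lstm_step(carry, x)
    outputs.append(y)

final_c, final_h = carry
```

If this reading is right, the carry threaded between steps is the pair (c, h), and the second return value is just the hidden state repeated as the step output, so the docstring describing the carry as only "the hidden state" would be incomplete.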
Anyway, thanks for the great library!