Add ConvLSTM2D instead of Conv2D layer for better performance? #6
Hi, thanks for the suggestion! Yes, it could, possibly, who knows… So far, all my experiments have been huge failures, with up to 4 days of training leading to zero improvement over the random policy.
I got an error while trying to replace Conv2D with ConvLSTM2D without any other modification, and I have the feeling that it won't be trivial to do the change. FYI:
Traceback (most recent call last):
  File "run-mario.py", line 204, in <module>
    mario_main(N=PARALLEL_EMULATORS)
  File "run-mario.py", line 127, in mario_main
    name=dqn_model_name
  File "/home/lilian_besson/gym-nes-mario-bros.git/src/dqn/model.py", line 164, in __init__
    self.base_model = q_model(input_shape, num_actions)
  File "/home/lilian_besson/gym-nes-mario-bros.git/src/dqn/model.py", line 107, in q_model
    inputs, outputs = q_function(input_shape, num_actions)
  File "/home/lilian_besson/gym-nes-mario-bros.git/src/dqn/model.py", line 80, in q_function
    out = ConvLSTM2D(filters=32, kernel_size=8, strides=(4, 4), padding=padding, activation='relu')(image_input)
  File "/usr/local/lib/python3.5/dist-packages/keras/layers/convolutional_recurrent.py", line 277, in __call__
    return super(ConvRNN2D, self).__call__(inputs, **kwargs)
  File "/usr/local/lib/python3.5/dist-packages/keras/layers/recurrent.py", line 499, in __call__
    return super(RNN, self).__call__(inputs, **kwargs)
  File "/usr/local/lib/python3.5/dist-packages/keras/engine/topology.py", line 575, in __call__
    self.assert_input_compatibility(inputs)
  File "/usr/local/lib/python3.5/dist-packages/keras/engine/topology.py", line 474, in assert_input_compatibility
    str(K.ndim(x)))
ValueError: Input 0 is incompatible with layer conv_lst_m2d_1: expected ndim=5, found ndim=4
An LSTM learns patterns across time: it models the relationship between the data at one time step and the data at another, while still learning the relationship between inputs and outputs. To do this, the LSTM needs a time-step dimension, which basically means grouping consecutive rows of your dataset into sequences before feeding them to the network. So you will need to reshape your dataset, which is currently 4-dimensional, into 5 dimensions; that is what the last line of the error (the ValueError) means. Assuming your dataset shape is as follows:
=> (1023, 6, 4, 28, 28) You may also use the NumPy reshape method, but I prefer it this way.
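For reference, the NumPy reshape alternative mentioned above might look like the sketch below. The input shape `(6138, 4, 28, 28)` and the sequence length of 6 are assumptions, chosen only so that the result matches the 5-dimensional shape quoted above; they are not taken from the project, and the names `frames`, `time_steps`, and `sequences` are illustrative.

```python
import numpy as np

# Hypothetical 4-D dataset (assumption): 6138 samples, each of shape (4, 28, 28).
frames = np.zeros((6138, 4, 28, 28), dtype=np.float32)

# Assumed number of consecutive samples grouped into one sequence.
time_steps = 6

# Drop any trailing samples that would not fill a complete sequence.
usable = (frames.shape[0] // time_steps) * time_steps

# Group consecutive samples into non-overlapping sequences, which adds
# the time-step axis that a recurrent layer expects (ndim goes from 4 to 5).
sequences = frames[:usable].reshape(-1, time_steps, *frames.shape[1:])

print(sequences.shape)  # (1023, 6, 4, 28, 28)
```

This groups samples into non-overlapping windows; whether that grouping, or the axis order, suits the DQN code in this repository would still need to be checked against the model's input definition.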
Hi @Nindaime,
I've been very interested in this project and have been playing with it locally.
I think it would be helpful to make at least one of the layers (perhaps all of the Conv2D ones) a ConvLSTM2D (convolutional 2D LSTM) layer. I believe this would help the network find spatiotemporal relationships within the game screen and result in better models.