The concept is as follows: the player maneuvers a line that grows in length, with the line itself being a primary obstacle. The line resembles a moving snake that becomes longer as it eats food. The player loses when the snake runs into the screen border or into itself. The goal is to grow as long as possible before dying. In this project, we present several reinforcement learning approaches for efficiently training an agent to play Snake.
Install all required packages:

```
pip3 install -r requirements.txt
```

In `main.py`, choose the hyper-parameters and the model, then run:

```
python3 main.py
```

Snake is learning!
Four RL models are currently implemented in the framework: three classic models and one deep learning model. All models inherit from the `ABCTrainer` class, which contains all the abstract methods needed to create a new model (`choose_action`, `update_hyperparameters`, ...). Check the `trainer` folder to see which hyperparameters each model needs; a minimal skeleton of a new trainer is sketched below.
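For orientation, here is roughly what subclassing `ABCTrainer` could look like. This is a sketch, not the repository's actual API: the module path, method signatures, and action encoding are all assumptions.

```python
import random

from trainer.abc_trainer import ABCTrainer  # hypothetical module path


class RandomTrainer(ABCTrainer):
    """Toy trainer that picks actions uniformly at random."""

    def __init__(self, epsilon_init, epsilon_decay):
        self.epsilon = epsilon_init
        self.epsilon_decay = epsilon_decay

    def choose_action(self, state):
        # A real model would exploit its learned policy here;
        # this baseline only explores. The action encoding is assumed.
        return random.choice([0, 1, 2, 3])  # up, right, down, left

    def update_hyperparameters(self):
        # Decay the exploration rate after each episode.
        self.epsilon *= self.epsilon_decay
```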
- For classic models: in `main`, import a pretrained `q_table` by using the `import_q_table` function to instantiate the trainer (see the sketch after this list).
- For `DQLTrainer`: use the `load_model_path` attribute of `DQLTrainer`. We provide PyTorch pretrained weights that can save you hours of training before getting results.
- Be careful to set `epsilon_init = 0` when using a pretrained model, as in the `DQLTrainer` example further below.
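For a classic model, the corresponding code in `main` might look like the following sketch; the file path, trainer class name (`QTrainer`), and constructor arguments are assumptions, and only `import_q_table` comes from the framework:

```python
# Sketch only: QTrainer and the constructor/loader arguments are
# illustrative assumptions; import_q_table is the framework's loader.
q_table = import_q_table("pretrained_models/q_table.json")
trainer = QTrainer(epsilon_init=0, q_table=q_table)
```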
For `DQLTrainer` with pretrained weights:

```python
epsilon_init = 0
DQLTrainer(epsilon_init, epsilon_decay, learning_rate, gamma, decay_rate,
           size_x, size_y, load_model_path="pretrained_models/trained_weights.pth")
```

It is also possible to stop the training of your model by using the `test` argument of the trainers' `iterate` method.
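An evaluation run could then look like the loop below, assuming the trainer built above is bound to a variable `trainer` and that `iterate` returns the episode score (both assumptions; only the `test` flag is documented here):

```python
# Illustrative evaluation loop; whether iterate returns the episode
# score is an assumption, only the test flag is part of the framework.
for episode in range(10):
    score = trainer.iterate(test=True)  # play one episode without training
    print(f"episode {episode}: score {score}")
```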
`grid_search.py` allows you to test a wide range of hyper-parameters on the model of your choice; it returns a JSON file with the results.
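Conceptually, such a grid search enumerates every combination of the candidate values and records the score of each run. The sketch below shows the idea; `train_and_score` is a hypothetical placeholder for an actual training run, and `grid_search.py`'s real interface may differ:

```python
import itertools
import json


def train_and_score(learning_rate, gamma, epsilon_decay):
    # Hypothetical stand-in for training a model and returning its score.
    return 0.0


grid = {
    "learning_rate": [0.001, 0.01],
    "gamma": [0.9, 0.99],
    "epsilon_decay": [0.99, 0.995],
}

results = []
for lr, g, eps in itertools.product(*grid.values()):
    results.append({
        "learning_rate": lr,
        "gamma": g,
        "epsilon_decay": eps,
        "score": train_and_score(lr, g, eps),
    })

# Persist all runs so the best configuration can be picked afterwards.
with open("grid_search_results.json", "w") as f:
    json.dump(results, f, indent=2)
```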
More details can be found in our paper Snake-RL.pdf.