lilGym Baselines

This repository contains example code for training the baselines from the paper lilGym: Natural Language Visual Reasoning with Reinforcement Learning. Trained models are available on Zenodo: link.

paper | TL;DR tweet | env code & data | website

Installation

Note: this code has been tested with PyTorch 1.12.1 and CUDA 11.2.

  1. Install lilgym and its dependencies by following the installation instructions (this also installs PyTorch).

  2. Clone the current repo.

  3. Install the Python dependencies:

cd /path/to/lilgym-baselines
pip install -r requirements.txt
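
To confirm the setup matches the tested configuration noted above, a quick check of the PyTorch and CUDA installation can help. This is a minimal sketch using only standard PyTorch calls:

# Quick check that PyTorch is installed and CUDA is visible,
# matching the tested configuration noted above.
import torch

print("torch version:", torch.__version__)         # tested with 1.12.1
print("cuda available:", torch.cuda.is_available())
print("cuda version:", torch.version.cuda)         # tested with CUDA 11.2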

Training

Examples of training commands

Training a C3+BERT model with PPO+SF (stop-forcing) on the TowerScratch environment:

python main.py --env-name TowerScratch-v0 --env-opt tower --learn-opt scratch --algo ppo --stop-forcing --seed 1 --model c3bert --text-feat bertfix --num-processes 1 --num-steps 2048 --lr 3e-4 --entropy-coef 0.1 --value-loss-coef 0.5 --ppo-epoch 10 --num-mini-batch 32 --gamma 0.99 --gae-lambda 0.95 --num-env-steps 4000000 --use-gae --optim-type adam --scheduler linear --warmup-percent 0 --log-interval 1 --eval-interval 10 --log-dir ${path} --save-dir ${path} --save-interval 20 --wandb --wandb-run-name name-of-the-run
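
The hyperparameter flags follow the conventions of the underlying PPO codebase (see Acknowledgements): --num-steps is the rollout length per update, --ppo-epoch and --num-mini-batch control the PPO optimization phase, --gamma and --gae-lambda set the discount factor and GAE parameter, and ${path} is a placeholder for your log/checkpoint directory.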

Training a ViLT model with PPO on the TowerFlipIt environment:

python main.py --env-name TowerFlipIt-v0 --env-opt tower --learn-opt flipit --algo ppo --stop-forcing --seed 1 --model vilt --num-processes 1 --num-steps 2048 --lr 3e-5 --entropy-coef 0.1 --value-loss-coef 0.5 --ppo-epoch 10 --num-mini-batch 32 --gamma 0.99 --gae-lambda 0.95 --num-env-steps 4000000 --use-gae --optim-type adamw --scheduler cosine --warmup-percent 0.01 --log-interval 1 --eval-interval 10 --log-dir ${path} --save-dir ${path} --save-interval 20 --wandb --wandb-run-name name-of-the-run
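
The commands above drive a PPO loop over the lilGym environments. For reference, interacting with an environment directly looks roughly like the sketch below. This is illustrative, not code from this repo: it assumes that importing lilgym registers the environment IDs with Gym (as is typical for Gym extensions; see the lilgym repo for the exact import), and it uses the pre-0.26 Gym reset/step API that was current for the tested PyTorch 1.12.1 setup.

# Sketch of direct interaction with a lilGym environment (not repo code).
# Assumption: importing lilgym registers its environments with Gym.
import gym
import lilgym  # assumed registration import; see the lilgym repo for the exact form

env = gym.make("TowerScratch-v0")  # same env ID as in the first training command

obs = env.reset()  # pre-0.26 Gym API: reset() returns the observation
for _ in range(100):  # cap the rollout length for this sketch
    action = env.action_space.sample()          # random policy, just to exercise the env
    obs, reward, done, info = env.step(action)  # pre-0.26 Gym API: 4-tuple
    if done:
        break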

Acknowledgements

The RL code is based on Kostrikov, 2018. We thank the authors for open-sourcing their code.