Jax-Baseline is a Reinforcement Learning implementation built on JAX and the Flax/Haiku libraries, mirroring the functionality of Stable-Baselines.
- 2-3x faster than comparable PyTorch and TensorFlow implementations
- Optimized with JAX's Just-In-Time (JIT) compilation
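Much of this speedup comes from JIT-compiling the update math into fused XLA kernels. A minimal, repo-independent sketch of the pattern in plain JAX (this is not Jax-Baseline's actual API, just an illustration of the technique):

```python
import jax
import jax.numpy as jnp

@jax.jit  # trace once, then reuse the compiled XLA kernel on every call
def td_targets(rewards, next_q, dones, gamma=0.99):
    # Standard 1-step TD target: r + gamma * max_a Q(s', a) * (1 - done)
    return rewards + gamma * jnp.max(next_q, axis=1) * (1.0 - dones)

rewards = jnp.array([1.0, 0.0])
next_q = jnp.array([[0.5, 1.5], [2.0, 0.1]])
dones = jnp.array([0.0, 1.0])
print(td_targets(rewards, next_q, dones))  # first call compiles; later calls are fast
```

The first call pays a one-time tracing/compilation cost; subsequent calls with the same shapes run the cached kernel, which is where the throughput gains over eager-mode frameworks come from.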
```shell
pip install -r requirement.txt
pip install .
```
- ✔️ : Implemented as an optional feature
- ✅ : Implemented as the default, following the original paper
- ❌ : Not implemented yet, or cannot be implemented
| Name | Q-Net based | Actor-Critic based | DPG based |
| --- | --- | --- | --- |
| Gymnasium | ✔️ | ✔️ | ✔️ |
| EnvPool | ✔️ | ✔️ | ✔️ |
| Name | Double[1] | Dueling[2] | PER[3] | N-step[4][5] | NoisyNet[6] | Munchausen[7] | Ape-X[8] | HL-Gauss[9] |
| --- | --- | --- | --- | --- | --- | --- | --- | --- |
| DQN[10] | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ❌ |
| C51[11] | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ |
| QRDQN[12] | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ❌ |
| IQN[13] | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ❌ | ❌ |
| FQF[14] | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ❌ | ❌ |
| SPR[15] | ✅ | ✅ | ✅ | ✅ | ✅ | ✔️ | ❌ | ✔️ |
| BBF[16] | ✅ | ✅ | ✅ | ✅ | ✔️ | ✔️ | ❌ | ✔️ |
| Name | Box | Discrete | IMPALA[17] |
| --- | --- | --- | --- |
| A2C[18] | ✔️ | ✔️ | ✔️ |
| PPO[19] | ✔️ | ✔️ | ✔️[20] |
| Truly PPO (TPPO)[21] | ✔️ | ✔️ | ❌ |
| SPO[22] | ✔️ | ✔️ | ❌ |
| Name | PER[3] | N-step[4][5] | Ape-X[8] | Simba[23] | Simba-v2[24] |
| --- | --- | --- | --- | --- | --- |
| DDPG[25] | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ |
| TD3[26] | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ |
| SAC[27] | ✔️ | ✔️ | ❌ | ✔️ | ✔️ |
| DAC[28] ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
| TQC[29] | ✔️ | ✔️ | ❌ | ✔️ | ✔️ |
| TD7[30] | ✅ (LAP[31]) | ❌ | ❌ | ✔️ | ✔️ |
| CrossQ[32] | ✔️ | ✔️ | ❌ | ✔️ | ✔️ |
| BRO[33] ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
To test Atari with DQN (or C51, QRDQN, IQN, FQF):

```shell
python test/run_qnet.py --algo DQN --env BreakoutNoFrameskip-v4 --learning_rate 0.0002 \
    --steps 5e5 --batch 32 --train_freq 1 --target_update 1000 --node 512 \
    --hidden_n 1 --final_eps 0.01 --learning_starts 20000 --gamma 0.995 --clip_rewards
```
On Atari Breakout, 500K steps complete in about 15 minutes (~540 steps/sec), measured on an Nvidia RTX 3080 and an AMD Ryzen 9 5950X in a single process.

```
score : 9.600, epsilon : 0.010, loss : 0.181 |: 100%|███████| 500000/500000 [15:24<00:00, 540.88it/s]
```