Implemented CrossQ (#243)

* Implemented CrossQ

* Fixed code style

* Clean up, comments and refactored to sbx variable names

* 1024 neuron Q function (sbx default)

* batch norm parameters as function arguments

* clean up. reshape instead of split

* Added policy delay

* fixed commit-checks

* Fix f-string

* Update documentation

* Rename to torch layers

* Fix for policy delay and minor edits

* Update tests

* Update documentation

* Update doc

* Add more tests for crossQ

* Improve doc and expose batchnorm params

* Add some comments and todos and fix type check

* Use torch module for BN

* Re-organize losses

* Add set_bn_training_mode

* Simplify network creation with new SB3 version, and fix default momentum

* Use different b1 for Adam as in original implementation

* Reformat TOML file

* Update CI workflow, skip mypy for 3.8

* Update CrossQ doc

* Use uv to download packages on github CI

* System install for Github CI

* Fix for pytorch install

* Use +cpu version

* Pytorch 2.5.0 doesn't support python 3.8

* Update comments

---------

Co-authored-by: Antonin Raffin <[email protected]>
danielpalen and araffin authored Oct 24, 2024
1 parent 3d9a975 commit 68828f3
Showing 20 changed files with 1,221 additions and 28 deletions.
15 changes: 10 additions & 5 deletions .github/workflows/ci.yml
@@ -30,21 +30,24 @@ jobs:
       - name: Install dependencies
         run: |
           python -m pip install --upgrade pip
+          # Use uv for faster downloads
+          pip install uv
           # cpu version of pytorch
-          pip install torch==2.1.1 --index-url https://download.pytorch.org/whl/cpu
+          # See https://github.com/astral-sh/uv/issues/1497
+          uv pip install --system torch==2.4.1+cpu --index https://download.pytorch.org/whl/cpu
           # Install Atari Roms
-          pip install autorom
+          uv pip install --system autorom
           wget https://gist.githubusercontent.com/jjshoots/61b22aefce4456920ba99f2c36906eda/raw/00046ac3403768bfe45857610a3d333b8e35e026/Roms.tar.gz.b64
           base64 Roms.tar.gz.b64 --decode &> Roms.tar.gz
           AutoROM --accept-license --source-file Roms.tar.gz
           # Install master version
           # and dependencies for docs and tests
-          pip install "stable_baselines3[extra_no_roms,tests,docs] @ git+https://github.com/DLR-RM/stable-baselines3"
-          pip install .
+          uv pip install --system "stable_baselines3[extra_no_roms,tests,docs] @ git+https://github.com/DLR-RM/stable-baselines3"
+          uv pip install --system .
           # Use headless version
-          pip install opencv-python-headless
+          uv pip install --system opencv-python-headless
       - name: Lint with ruff
         run: |
@@ -58,6 +61,8 @@ jobs:
       - name: Type check
         run: |
           make type
+        # Do not run for python 3.8 (mypy internal error)
+        if: matrix.python-version != '3.8'
       - name: Test with pytest
         run: |
           make pytest
1 change: 1 addition & 0 deletions README.md
@@ -31,6 +31,7 @@ See documentation for the full list of included features.
 - [PPO with recurrent policy (RecurrentPPO aka PPO LSTM)](https://ppo-details.cleanrl.dev//2021/11/05/ppo-implementation-details/)
 - [Truncated Quantile Critics (TQC)](https://arxiv.org/abs/2005.04269)
 - [Trust Region Policy Optimization (TRPO)](https://arxiv.org/abs/1502.05477)
+- [Batch Normalization in Deep Reinforcement Learning (CrossQ)](https://openreview.net/forum?id=PczQtTsTIX)

 **Gym Wrappers**:
 - [Time Feature Wrapper](https://arxiv.org/abs/1712.00378)
7 changes: 7 additions & 0 deletions docs/common/torch_layers.rst
@@ -0,0 +1,7 @@
.. _th_layers:

Torch Layers
============

.. automodule:: sb3_contrib.common.torch_layers
    :members:
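
For reference, here is a minimal usage sketch of the ``BatchRenorm1d`` layer added for CrossQ (only the feature-dimension argument is assumed here; check the class signature for the remaining options):

.. code-block:: python

    import torch as th

    from sb3_contrib.common.torch_layers import BatchRenorm1d

    # Batch Renormalization (Ioffe, 2017): behaves like BatchNorm early in
    # training and progressively switches to the running statistics, keeping
    # train- and eval-time behavior consistent, which CrossQ relies on.
    layer = BatchRenorm1d(64)

    layer.train()  # training mode: update the running statistics
    out = layer(th.randn(32, 64))

    layer.eval()  # eval mode: normalize with the running statistics only
    out = layer(th.randn(32, 64))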
1 change: 1 addition & 0 deletions docs/guide/algos.rst
@@ -10,6 +10,7 @@ Name ``Box`` ``Discrete`` ``MultiDiscrete`` ``MultiBinary`` Multi Processing
 ============ =========== ============ ================= =============== ================
 ARS          ✔️           ❌           ❌                ❌               ✔️
 MaskablePPO  ❌           ✔️           ✔️                ✔️               ✔️
+CrossQ       ✔️           ❌           ❌                ❌               ✔️
 QR-DQN       ❌           ✔️           ❌                ❌               ✔️
 RecurrentPPO ✔️           ✔️           ✔️                ✔️               ✔️
 TQC          ✔️           ❌           ❌                ❌               ✔️
23 changes: 23 additions & 0 deletions docs/guide/examples.rst
@@ -113,3 +113,26 @@ Train a PPO agent with a recurrent policy on the CartPole environment.
        obs, rewards, dones, info = vec_env.step(action)
        episode_starts = dones
        vec_env.render("human")

CrossQ
------

Train a CrossQ agent on the Pendulum environment.

.. code-block:: python

    from sb3_contrib import CrossQ

    model = CrossQ(
        "MlpPolicy",
        "Pendulum-v1",
        verbose=1,
        policy_kwargs=dict(
            net_arch=dict(
                pi=[256, 256],
                qf=[1024, 1024],
            )
        ),
    )
    model.learn(total_timesteps=5_000, log_interval=4)
    model.save("crossq_pendulum")
Binary file added docs/images/crossQ_performance.png
2 changes: 2 additions & 0 deletions docs/index.rst
@@ -32,6 +32,7 @@ RL Baselines3 Zoo also offers a simple interface to train, evaluate agents and d
    :caption: RL Algorithms

    modules/ars
+   modules/crossq
    modules/ppo_mask
    modules/ppo_recurrent
    modules/qrdqn
@@ -42,6 +43,7 @@ RL Baselines3 Zoo also offers a simple interface to train, evaluate agents and d
    :maxdepth: 1
    :caption: Common

+   common/torch_layers
    common/utils
    common/wrappers
10 changes: 7 additions & 3 deletions docs/misc/changelog.rst
@@ -3,16 +3,19 @@
 Changelog
 ==========


-Release 2.4.0a9 (WIP)
+Release 2.4.0a10 (WIP)
 --------------------------

+**New algorithm: added CrossQ**

 Breaking Changes:
 ^^^^^^^^^^^^^^^^^
 - Upgraded to Stable-Baselines3 >= 2.4.0

 New Features:
 ^^^^^^^^^^^^^
+- Added ``CrossQ`` algorithm, from "Batch Normalization in Deep Reinforcement Learning" paper (@danielpalen)
+- Added ``BatchRenorm`` PyTorch layer used in ``CrossQ`` (@danielpalen)

 Bug Fixes:
 ^^^^^^^^^^
@@ -28,6 +31,7 @@ Others:
 ^^^^^^^
 - Updated PyTorch version on CI to 2.3.1
 - Remove unnecessary SDE noise resampling in PPO/TRPO update
+- Switched to uv to download packages on GitHub CI

 Documentation:
 ^^^^^^^^^^^^^^
@@ -584,4 +588,4 @@ Contributors:
 -------------

 @ku2482 @guyk1971 @minhlong94 @ayeright @kronion @glmcdona @cyprienc @sgillen @Gregwar @rnederstigt @qgallouedec
-@mlodel @CppMaster @burakdmb @honglu2875 @ZikangXiong @AlexPasqua @jonasreiher @icheered @Armandpl @corentinlger
+@mlodel @CppMaster @burakdmb @honglu2875 @ZikangXiong @AlexPasqua @jonasreiher @icheered @Armandpl @danielpalen @corentinlger
134 changes: 134 additions & 0 deletions docs/modules/crossq.rst
@@ -0,0 +1,134 @@
.. _crossq:

.. automodule:: sb3_contrib.crossq


CrossQ
======

Implementation of CrossQ proposed in:

`Bhatt A.* & Palenicek D.* et al. Batch Normalization in Deep Reinforcement Learning for Greater Sample Efficiency and Simplicity. ICLR 2024.`

CrossQ is an algorithm that uses batch normalization to improve the sample efficiency of off-policy deep reinforcement learning algorithms.
It is based on the idea of carefully introducing batch normalization layers in the critic network and dropping target networks.
This results in a simpler and more sample-efficient algorithm without requiring high update-to-data ratios.
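
The core idea can be sketched as follows. This is a simplified critic update for illustration only: it omits the entropy term and the minimum over the two critics, and the function name is not part of the actual API.

.. code-block:: python

    import torch as th
    import torch.nn.functional as F

    def critic_loss(critic, obs, actions, next_obs, next_actions, rewards, dones, gamma=0.99):
        batch_size = obs.shape[0]
        # Joint forward pass: the batch norm layers normalize current and next
        # state-action pairs with shared statistics, removing the distribution
        # shift that makes bootstrapping without a target network unstable.
        all_q = critic(th.cat([obs, next_obs]), th.cat([actions, next_actions]))
        # Reshape (instead of split) to recover the two halves of the batch
        current_q, next_q = all_q.reshape(2, batch_size, -1)
        # No target network: bootstrap from the same critic and simply
        # stop the gradient through the next Q-values
        target_q = rewards + (1.0 - dones) * gamma * next_q.detach()
        return F.mse_loss(current_q, target_q)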

.. rubric:: Available Policies

.. autosummary::
    :nosignatures:

    MlpPolicy

.. note::

    Compared to the original implementation, the default network architecture for the Q-value function is ``[1024, 1024]``
    instead of ``[2048, 2048]``, as it provides a good compromise between speed and performance.
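
To reproduce the architecture of the original implementation instead, override the critic size via ``policy_kwargs``:

.. code-block:: python

    from sb3_contrib import CrossQ

    # Use the original [2048, 2048] critic instead of the default [1024, 1024]
    model = CrossQ(
        "MlpPolicy",
        "Pendulum-v1",
        policy_kwargs=dict(net_arch=dict(pi=[256, 256], qf=[2048, 2048])),
    )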

.. note::

    There is currently no ``CnnPolicy`` for using CrossQ with images. We welcome help from contributors to add this feature.


Notes
-----

- Original paper: https://openreview.net/pdf?id=PczQtTsTIX
- Original Implementation: https://github.com/adityab/CrossQ
- SBX (SB3 Jax) Implementation: https://github.com/araffin/sbx


Can I use?
----------

- Recurrent policies: ❌
- Multi processing: ✔️
- Gym spaces:


============= ====== ===========
Space         Action Observation
============= ====== ===========
Discrete      ❌     ✔️
Box           ✔️     ✔️
MultiDiscrete ❌     ✔️
MultiBinary   ❌     ✔️
Dict          ❌     ❌
============= ====== ===========


Example
-------

.. code-block:: python

    from sb3_contrib import CrossQ

    model = CrossQ("MlpPolicy", "Walker2d-v4")
    model.learn(total_timesteps=1_000_000)
    model.save("crossq_walker")

Results
-------

Performance evaluation of CrossQ on six MuJoCo environments, see `PR #243 <https://github.com/Stable-Baselines-Team/stable-baselines3-contrib/pull/243>`_.
Results are compared to those reported in the original paper as well as to the `SBX <https://github.com/araffin/sbx>`_ implementation.

.. image:: ../images/crossQ_performance.png


Open RL benchmark report: https://wandb.ai/openrlbenchmark/sb3-contrib/reports/SB3-Contrib-CrossQ--Vmlldzo4NTE2MTEx


How to replicate the results?
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Clone RL-Zoo:

.. code-block:: bash

    git clone https://github.com/DLR-RM/rl-baselines3-zoo
    cd rl-baselines3-zoo/

Run the benchmark (replace ``$ENV_ID`` by the envs mentioned above):

.. code-block:: bash

    python train.py --algo crossq --env $ENV_ID --n-eval-envs 5 --eval-episodes 20 --eval-freq 25000

Plot the results:

.. code-block:: bash

    python scripts/all_plots.py -a crossq -e HalfCheetah Ant Hopper Walker2D -f logs/ -o logs/crossq_results
    python scripts/plot_from_file.py -i logs/crossq_results.pkl -latex -l CrossQ

Comments
--------

This implementation is based on the SB3 SAC implementation.
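
Compared to SAC, there are no target networks, the critic contains batch norm layers, and the policy is updated less frequently (policy delay). A simplified sketch of one gradient step is shown below; ``set_bn_training_mode`` is the helper added in this PR, while ``compute_critic_loss``, ``compute_actor_loss`` and ``policy_delay=3`` are illustrative stand-ins, not the actual source.

.. code-block:: python

    def gradient_step(policy, batch, step, policy_delay=3):
        # Critic update: batch norm in training mode, so that the joint
        # (obs, next_obs) forward pass updates the running statistics
        policy.set_bn_training_mode(True)
        critic_loss = compute_critic_loss(policy.critic, batch)
        policy.critic.optimizer.zero_grad()
        critic_loss.backward()
        policy.critic.optimizer.step()

        # Delayed policy update, as in TD3
        if step % policy_delay == 0:
            # Batch norm in eval mode: use the running statistics
            # without updating them
            policy.set_bn_training_mode(False)
            actor_loss = compute_actor_loss(policy.actor, policy.critic, batch)
            policy.actor.optimizer.zero_grad()
            actor_loss.backward()
            policy.actor.optimizer.step()

Following the original implementation, the Adam optimizer also uses a non-default ``beta1`` coefficient.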


Parameters
----------

.. autoclass:: CrossQ
    :members:
    :inherited-members:

.. _crossq_policies:

CrossQ Policies
---------------

.. autoclass:: MlpPolicy
    :members:
    :inherited-members:

.. autoclass:: sb3_contrib.crossq.policies.CrossQPolicy
    :members:
    :noindex:
1 change: 0 additions & 1 deletion docs/modules/ppo_recurrent.rst
@@ -109,7 +109,6 @@ Clone the repo for the experiment:

     git clone https://github.com/DLR-RM/rl-baselines3-zoo
     cd rl-baselines3-zoo
-    git checkout feat/recurrent-ppo

 Run the benchmark (replace ``$ENV_ID`` by the envs mentioned above):
3 changes: 3 additions & 0 deletions pyproject.toml
@@ -54,3 +54,6 @@ exclude_lines = [
     "raise NotImplementedError()",
     "if typing.TYPE_CHECKING:",
 ]
+
+# [tool.pyright]
+# extraPaths = ["../torchy-baselines/"]
2 changes: 2 additions & 0 deletions sb3_contrib/__init__.py
@@ -1,6 +1,7 @@
 import os

 from sb3_contrib.ars import ARS
+from sb3_contrib.crossq import CrossQ
 from sb3_contrib.ppo_mask import MaskablePPO
 from sb3_contrib.ppo_recurrent import RecurrentPPO
 from sb3_contrib.qrdqn import QRDQN
@@ -14,6 +15,7 @@

 __all__ = [
     "ARS",
+    "CrossQ",
     "MaskablePPO",
     "RecurrentPPO",
     "QRDQN",