Skip to content

Commit

Permalink
Fix QRDQN loading target_update_interval (#259)
Browse files Browse the repository at this point in the history
* Fix QRDQN loading target_update_interval

* Update changelog

* Update version

---------

Co-authored-by: Antonin RAFFIN <[email protected]>
  • Loading branch information
jak3122 and araffin authored Oct 2, 2024
1 parent 42595a5 commit 3d9a975
Show file tree
Hide file tree
Showing 5 changed files with 21 additions and 9 deletions.
4 changes: 2 additions & 2 deletions CONTRIBUTING.md
Original file line number Diff line number Diff line change
Expand Up @@ -152,13 +152,13 @@ To run tests with `pytest`:
make pytest
```

Type checking with `pytype` and `mypy`:
Type checking with `mypy`:

```
make type
```

Codestyle check with `black`, `isort` and `flake8`:
Codestyle check with `black` and `ruff`:

```
make check-codestyle
Expand Down
3 changes: 2 additions & 1 deletion docs/misc/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ Changelog
==========


Release 2.4.0a8 (WIP)
Release 2.4.0a9 (WIP)
--------------------------

Breaking Changes:
Expand All @@ -19,6 +19,7 @@ Bug Fixes:
- Updated QR-DQN optimizer input to only include quantile_net parameters (@corentinlger)
- Updated QR-DQN paper link in docs (@corentinlger)
- Fixed a warning with PyTorch 2.4 when loading a `RecurrentPPO` model (You are using torch.load with weights_only=False)
- Fixed `target_update_interval` being modified when loading a `QRDQN` model (@jak3122)

Deprecations:
^^^^^^^^^^^^^
Expand Down
9 changes: 4 additions & 5 deletions sb3_contrib/qrdqn/qrdqn.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,8 +153,7 @@ def _setup_model(self) -> None:
self.exploration_schedule = get_linear_fn(
self.exploration_initial_eps, self.exploration_final_eps, self.exploration_fraction
)
# Account for multiple environments
# each call to step() corresponds to n_envs transitions

if self.n_envs > 1:
if self.n_envs > self.target_update_interval:
warnings.warn(
Expand All @@ -164,8 +163,6 @@ def _setup_model(self) -> None:
f"which corresponds to {self.n_envs} steps."
)

self.target_update_interval = max(self.target_update_interval // self.n_envs, 1)

def _create_aliases(self) -> None:
self.quantile_net = self.policy.quantile_net
self.quantile_net_target = self.policy.quantile_net_target
Expand All @@ -177,7 +174,9 @@ def _on_step(self) -> None:
This method is called in ``collect_rollouts()`` after each step in the environment.
"""
self._n_calls += 1
if self._n_calls % self.target_update_interval == 0:
# Account for multiple environments
# each call to step() corresponds to n_envs transitions
if self._n_calls % max(self.target_update_interval // self.n_envs, 1) == 0:
polyak_update(self.quantile_net.parameters(), self.quantile_net_target.parameters(), self.tau)
# Copy running stats, see https://github.com/DLR-RM/stable-baselines3/issues/996
polyak_update(self.batch_norm_stats, self.batch_norm_stats_target, 1.0)
Expand Down
2 changes: 1 addition & 1 deletion sb3_contrib/version.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
2.4.0a8
2.4.0a9
12 changes: 12 additions & 0 deletions tests/test_save_load.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import pytest
import torch as th
from stable_baselines3.common.base_class import BaseAlgorithm
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.envs import FakeImageEnv, IdentityEnv, IdentityEnvBox
from stable_baselines3.common.utils import get_device
from stable_baselines3.common.vec_env import DummyVecEnv
Expand Down Expand Up @@ -481,3 +482,14 @@ def test_save_load_pytorch_var(tmp_path):
assert model.log_ent_coef is None
# Check that the entropy coefficient is still the same
assert th.allclose(ent_coef_before, ent_coef_after)


def test_dqn_target_update_interval(tmp_path):
    """Check that a save/load round trip keeps ``target_update_interval`` intact.

    Regression test for GH Issue #258: the model is built on an environment
    created with ``n_envs=2`` and ``target_update_interval=100``; after
    reloading, the attribute must still be 100.
    """
    save_path = tmp_path / "dqn_cartpole"
    vec_env = make_vec_env(env_id="CartPole-v1", n_envs=2)

    # Build and save in one go so that no reference to the original model survives.
    QRDQN("MlpPolicy", vec_env, verbose=1, target_update_interval=100).save(save_path)

    reloaded = QRDQN.load(save_path)
    # Clean up the archive before asserting, as the value is already in memory.
    os.remove(tmp_path / "dqn_cartpole.zip")
    assert reloaded.target_update_interval == 100

0 comments on commit 3d9a975

Please sign in to comment.