Add rollout_buffer_class to TRPO (#214)
* Add rollout_buffer_class and rollout_buffer_kwargs to TRPO

* Update requirements and changelog

---------

Co-authored-by: Antonin Raffin <[email protected]>
ernestum and araffin authored Oct 30, 2023
1 parent 301a8b3 commit b5e6518
Showing 4 changed files with 11 additions and 3 deletions.
3 changes: 2 additions & 1 deletion docs/misc/changelog.rst
@@ -4,7 +4,7 @@ Changelog
==========


-Release 2.2.0a8 (WIP)
+Release 2.2.0a9 (WIP)
--------------------------

Breaking Changes:
@@ -16,6 +16,7 @@ Breaking Changes:
New Features:
^^^^^^^^^^^^^
- Added ``set_options`` for ``AsyncEval``
+- Added ``rollout_buffer_class`` and ``rollout_buffer_kwargs`` arguments to TRPO

Bug Fixes:
^^^^^^^^^^
7 changes: 7 additions & 0 deletions sb3_contrib/trpo/trpo.py
@@ -6,6 +6,7 @@
import numpy as np
import torch as th
from gymnasium import spaces
+from stable_baselines3.common.buffers import RolloutBuffer
from stable_baselines3.common.distributions import kl_divergence
from stable_baselines3.common.on_policy_algorithm import OnPolicyAlgorithm
from stable_baselines3.common.policies import ActorCriticPolicy, BasePolicy
@@ -53,6 +54,8 @@ class TRPO(OnPolicyAlgorithm):
instead of action noise exploration (default: False)
:param sde_sample_freq: Sample a new noise matrix every n steps when using gSDE
Default: -1 (only sample at the beginning of the rollout)
+:param rollout_buffer_class: Rollout buffer class to use. If ``None``, it will be automatically selected.
+:param rollout_buffer_kwargs: Keyword arguments to pass to the rollout buffer on creation
:param normalize_advantage: Whether to normalize or not the advantage
:param target_kl: Target Kullback-Leibler divergence between updates.
Should be small for stability. Values like 0.01, 0.05.
@@ -91,6 +94,8 @@ def __init__(
gae_lambda: float = 0.95,
use_sde: bool = False,
sde_sample_freq: int = -1,
+rollout_buffer_class: Optional[Type[RolloutBuffer]] = None,
+rollout_buffer_kwargs: Optional[Dict[str, Any]] = None,
normalize_advantage: bool = True,
target_kl: float = 0.01,
sub_sampling_factor: int = 1,
@@ -114,6 +119,8 @@ def __init__(
max_grad_norm=0.0,
use_sde=use_sde,
sde_sample_freq=sde_sample_freq,
+rollout_buffer_class=rollout_buffer_class,
+rollout_buffer_kwargs=rollout_buffer_kwargs,
stats_window_size=stats_window_size,
tensorboard_log=tensorboard_log,
policy_kwargs=policy_kwargs,
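
The change above only threads two new constructor arguments through to the parent class. A minimal, standard-library-only sketch of that forwarding pattern might look like the following. All class names here (`RolloutBufferStub`, `TaggedBufferStub`, `OnPolicyStub`, `TRPOStub`) are hypothetical stand-ins for illustration, not the real Stable-Baselines3 classes, and the default-selection logic is an assumption based on the docstring ("If ``None``, it will be automatically selected").

```python
from typing import Any, Dict, Optional, Type


class RolloutBufferStub:
    """Hypothetical stand-in for a rollout buffer (not the SB3 class)."""

    def __init__(self, buffer_size: int) -> None:
        self.buffer_size = buffer_size


class TaggedBufferStub(RolloutBufferStub):
    """Hypothetical custom buffer that accepts one extra keyword argument."""

    def __init__(self, buffer_size: int, tag: str = "default") -> None:
        super().__init__(buffer_size)
        self.tag = tag


class OnPolicyStub:
    """Hypothetical parent class that owns buffer creation."""

    def __init__(
        self,
        n_steps: int,
        rollout_buffer_class: Optional[Type[RolloutBufferStub]] = None,
        rollout_buffer_kwargs: Optional[Dict[str, Any]] = None,
    ) -> None:
        # Assumed behavior: fall back to a default buffer when none is given.
        self.rollout_buffer_class = rollout_buffer_class or RolloutBufferStub
        self.rollout_buffer_kwargs = rollout_buffer_kwargs or {}
        # Extra keyword arguments are forwarded at creation time.
        self.rollout_buffer = self.rollout_buffer_class(
            n_steps, **self.rollout_buffer_kwargs
        )


class TRPOStub(OnPolicyStub):
    """Hypothetical subclass: accepts both arguments and passes them up."""

    def __init__(
        self,
        n_steps: int = 2048,
        rollout_buffer_class: Optional[Type[RolloutBufferStub]] = None,
        rollout_buffer_kwargs: Optional[Dict[str, Any]] = None,
    ) -> None:
        super().__init__(
            n_steps,
            rollout_buffer_class=rollout_buffer_class,
            rollout_buffer_kwargs=rollout_buffer_kwargs,
        )


default_model = TRPOStub()
custom_model = TRPOStub(
    n_steps=8,
    rollout_buffer_class=TaggedBufferStub,
    rollout_buffer_kwargs={"tag": "experiment-1"},
)
```

With the real library, the corresponding call would presumably be along the lines of `TRPO("MlpPolicy", env, rollout_buffer_class=MyBuffer, rollout_buffer_kwargs={...})`, where `MyBuffer` is a hypothetical subclass of `RolloutBuffer`. Keeping both defaults at `None` means existing callers are unaffected by the new parameters.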
2 changes: 1 addition & 1 deletion sb3_contrib/version.txt
@@ -1 +1 @@
-2.2.0a8
+2.2.0a9
2 changes: 1 addition & 1 deletion setup.py
@@ -65,7 +65,7 @@
packages=[package for package in find_packages() if package.startswith("sb3_contrib")],
package_data={"sb3_contrib": ["py.typed", "version.txt"]},
install_requires=[
-"stable_baselines3>=2.2.0a8,<3.0",
+"stable_baselines3>=2.2.0a9,<3.0",
],
description="Contrib package of Stable Baselines3, experimental code.",
author="Antonin Raffin",