[RLlib] Add TQC (Truncated Quantile Critics) algorithm implementation #59808
Merged · +1,432 −0 · 8 commits
Commits:

- dc2c31e: [RLlib] Add TQC (Truncated Quantile Critics) algorithm implementation (tk42)
- 0c22910: [RLlib] Refactor TQC implementation - remove unused code and improve … (tk42)
- 148c1cd: [RLlib] Remove unused imports from TQC implementation (tk42)
- b9540a0: [RLlib] Code formatting and cleanup for TQC implementation (tk42)
- 15c7a9c: [RLlib] Add TQC example scripts and refactor module implementation (tk42)
- e536c3e: [RLlib] Refactor TQC torch module to inherit from DefaultTQCRLModule … (tk42)
- 33db159: [RLlib] Add validation for top_quantiles_to_drop_per_net parameter in… (tk42)
- 20c2130: Merge branch 'master' into master (simonsays1980)
New file (+95 lines):
# TQC (Truncated Quantile Critics)

## Overview

TQC is an extension of SAC (Soft Actor-Critic) that uses distributional reinforcement learning with quantile regression to control overestimation bias in the Q-function.

**Paper**: [Controlling Overestimation Bias with Truncated Mixture of Continuous Distributional Quantile Critics](https://arxiv.org/abs/2005.04269)
## Key Features

- **Distributional Critics**: Each critic network outputs multiple quantiles instead of a single Q-value
- **Multiple Critics**: Uses `n_critics` independent critic networks (default: 2)
- **Truncated Targets**: Drops the top quantiles when computing target Q-values to reduce overestimation
- **Quantile Huber Loss**: Uses quantile regression with Huber loss for critic training
## Usage

```python
from ray.rllib.algorithms.tqc import TQCConfig

config = (
    TQCConfig()
    .environment("Pendulum-v1")
    .training(
        n_quantiles=25,  # Number of quantiles per critic
        n_critics=2,  # Number of critic networks
        top_quantiles_to_drop_per_net=2,  # Quantiles to drop for bias control
    )
)

algo = config.build()
for _ in range(100):
    result = algo.train()
    print(f"Episode reward mean: {result['env_runners']['episode_reward_mean']}")
```
## Configuration

### TQC-Specific Parameters

| Parameter | Default | Description |
| --- | --- | --- |
| `n_quantiles` | 25 | Number of quantiles for each critic network |
| `n_critics` | 2 | Number of critic networks |
| `top_quantiles_to_drop_per_net` | 2 | Number of top quantiles to drop per network when computing targets |
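To make these defaults concrete, here is a small illustrative calculation (plain Python arithmetic, not an RLlib API call):

```python
# Illustrative only: how the defaults above combine during target computation.
n_quantiles = 25
n_critics = 2
top_quantiles_to_drop_per_net = 2

quantiles_total = n_quantiles * n_critics                      # 50 pooled target quantiles
quantiles_dropped = top_quantiles_to_drop_per_net * n_critics  # 4 largest quantiles removed
quantiles_kept = quantiles_total - quantiles_dropped           # 46 quantiles form the target
print(quantiles_total, quantiles_dropped, quantiles_kept)      # -> 50 4 46
```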
### Inherited from SAC

TQC inherits all SAC parameters, including:

- `actor_lr`, `critic_lr`, `alpha_lr`: Learning rates
- `tau`: Target network update coefficient
- `initial_alpha`: Initial entropy coefficient
- `target_entropy`: Target entropy for automatic alpha tuning
## Algorithm Details

### Critic Update

1. Each critic outputs `n_quantiles` quantile estimates
2. For target computation (see the sketch after this list):
   - Collect all quantiles from all critics: `n_critics * n_quantiles` values
   - Sort all quantiles
   - Drop the top `top_quantiles_to_drop_per_net * n_critics` quantiles
   - Use the remaining quantiles as targets
3. Train critics using the quantile Huber loss
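As an informal illustration of steps 2 and 3, here is a minimal PyTorch sketch. It is not the module code from this PR: the helper names, the `kappa = 1` Huber threshold, and the plain `.mean()` reduction are assumptions, and the reward, discount, and entropy terms of the full SAC-style target are omitted.

```python
import torch


def truncate_target_quantiles(
    target_quantiles: torch.Tensor,  # shape: (batch, n_critics * n_quantiles)
    n_critics: int,
    top_quantiles_to_drop_per_net: int,
) -> torch.Tensor:
    """Pool, sort, and drop the largest target quantiles (step 2)."""
    sorted_q, _ = torch.sort(target_quantiles, dim=-1)
    n_drop = top_quantiles_to_drop_per_net * n_critics
    return sorted_q[..., : sorted_q.shape[-1] - n_drop]


def quantile_huber_loss(
    current_quantiles: torch.Tensor,  # shape: (batch, n_quantiles)
    target_quantiles: torch.Tensor,   # shape: (batch, n_kept_quantiles)
) -> torch.Tensor:
    """Quantile regression with a Huber penalty, kappa = 1 (step 3)."""
    n = current_quantiles.shape[-1]
    # Quantile midpoints tau_i = (i + 0.5) / n.
    taus = (torch.arange(n, dtype=current_quantiles.dtype) + 0.5) / n
    # Pairwise TD errors, shape: (batch, n, n_kept_quantiles).
    td = target_quantiles.unsqueeze(-2) - current_quantiles.unsqueeze(-1)
    huber = torch.where(td.abs() <= 1.0, 0.5 * td.pow(2), td.abs() - 0.5)
    weight = (taus.view(1, -1, 1) - (td.detach() < 0).float()).abs()
    return (weight * huber).mean()
```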
### Actor Update

- Maximize the expected Q-value (mean of all quantiles) minus the entropy bonus, as sketched below
- Same as SAC, but using the mean of the quantile estimates
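A hedged sketch of that objective; `quantiles_all`, `log_probs`, and `alpha` are placeholder names, not identifiers from this PR:

```python
import torch

# Placeholder tensors: 256 transitions, 2 critics x 25 quantiles each.
quantiles_all = torch.randn(256, 50)  # critic quantiles for re-sampled actions
log_probs = torch.randn(256)          # log pi(a|s) of those actions
alpha = 0.2                           # entropy temperature

# Maximizing mean quantile value minus the entropy penalty == minimizing its negation.
q_mean = quantiles_all.mean(dim=-1)
actor_loss = (alpha * log_probs - q_mean).mean()
```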
### Entropy Tuning

- Same as SAC: automatically adjusts the temperature parameter α (see the sketch below)
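For reference, a sketch of the SAC-style temperature update the README alludes to; all variable names are placeholders:

```python
import torch

log_alpha = torch.zeros(1, requires_grad=True)  # learnable log-temperature
target_entropy = -1.0                           # e.g., -|A| for a 1-D action space
log_probs = torch.randn(256)                    # log pi(a|s) for sampled actions

# Pushes alpha up when policy entropy falls below target_entropy, down otherwise.
alpha_loss = -(log_alpha * (log_probs + target_entropy).detach()).mean()
alpha = log_alpha.exp().detach()
```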
## Differences from SAC

| Aspect | SAC | TQC |
| --- | --- | --- |
| Critic Output | Single Q-value | `n_quantiles` quantile values |
| Number of Critics | 2 (twin_q) | `n_critics` (configurable) |
| Loss Function | Huber/MSE | Quantile Huber Loss |
| Target Q | min(Q1, Q2) | Truncated sorted quantiles |
## References

```bibtex
@article{kuznetsov2020controlling,
  title={Controlling Overestimation Bias with Truncated Mixture of Continuous Distributional Quantile Critics},
  author={Kuznetsov, Arsenii and Shvechikov, Pavel and Grishin, Alexander and Vetrov, Dmitry},
  journal={arXiv preprint arXiv:2005.04269},
  year={2020}
}
```
New file (+13 lines):

```python
"""TQC (Truncated Quantile Critics) Algorithm.

Paper: https://arxiv.org/abs/2005.04269
"""

from ray.rllib.algorithms.tqc.tqc import TQC, TQCConfig
from ray.rllib.algorithms.tqc.tqc_catalog import TQCCatalog

__all__ = [
    "TQC",
    "TQCConfig",
    "TQCCatalog",
]
```
New file (+100 lines):

```python
"""
Default TQC RLModule.

TQC uses distributional critics with quantile regression.
"""

from typing import List, Tuple

from ray.rllib.core.learner.utils import make_target_network
from ray.rllib.core.rl_module.apis import InferenceOnlyAPI, QNetAPI, TargetNetworkAPI
from ray.rllib.core.rl_module.rl_module import RLModule
from ray.rllib.utils.annotations import (
    override,
)
from ray.rllib.utils.typing import NetworkType
from ray.util.annotations import DeveloperAPI


@DeveloperAPI
class DefaultTQCRLModule(RLModule, InferenceOnlyAPI, TargetNetworkAPI, QNetAPI):
    """RLModule for the TQC (Truncated Quantile Critics) algorithm.

    TQC extends SAC by using distributional critics with quantile regression.
    Each critic outputs n_quantiles values instead of a single Q-value.

    Architecture:
    - Policy (Actor): Same as SAC
      [obs] -> [pi_encoder] -> [pi_head] -> [action_dist_inputs]

    - Quantile Critics: Multiple critics, each outputting n_quantiles
      [obs, action] -> [qf_encoder_i] -> [qf_head_i] -> [n_quantiles values]

    - Target Quantile Critics: Target networks for each critic
      [obs, action] -> [target_qf_encoder_i] -> [target_qf_head_i] -> [n_quantiles]
    """

    @override(RLModule)
    def setup(self):
        # TQC-specific parameters from model_config
        self.n_quantiles = self.model_config.get("n_quantiles", 25)
        self.n_critics = self.model_config.get("n_critics", 2)
        self.top_quantiles_to_drop_per_net = self.model_config.get(
            "top_quantiles_to_drop_per_net", 2
        )

        # Total quantiles across all critics
        self.quantiles_total = self.n_quantiles * self.n_critics

        # Build the encoder for the policy (same as SAC)
        self.pi_encoder = self.catalog.build_encoder(framework=self.framework)

        if not self.inference_only or self.framework != "torch":
            # Build multiple Q-function encoders and heads
            self.qf_encoders = []
            self.qf_heads = []

            for i in range(self.n_critics):
                qf_encoder = self.catalog.build_qf_encoder(framework=self.framework)
                qf_head = self.catalog.build_qf_head(framework=self.framework)
                self.qf_encoders.append(qf_encoder)
                self.qf_heads.append(qf_head)

        # Build the policy head (same as SAC)
        self.pi = self.catalog.build_pi_head(framework=self.framework)

    @override(TargetNetworkAPI)
    def make_target_networks(self):
        """Creates target networks for all quantile critics."""
        self.target_qf_encoders = []
        self.target_qf_heads = []

        for i in range(self.n_critics):
            target_encoder = make_target_network(self.qf_encoders[i])
            target_head = make_target_network(self.qf_heads[i])
            self.target_qf_encoders.append(target_encoder)
            self.target_qf_heads.append(target_head)

    @override(InferenceOnlyAPI)
    def get_non_inference_attributes(self) -> List[str]:
        """Returns attributes not needed for inference."""
        return [
            "qf_encoders",
            "qf_heads",
            "target_qf_encoders",
            "target_qf_heads",
        ]

    @override(TargetNetworkAPI)
    def get_target_network_pairs(self) -> List[Tuple[NetworkType, NetworkType]]:
        """Returns pairs of (network, target_network) for updating targets."""
        pairs = []
        for i in range(self.n_critics):
            pairs.append((self.qf_encoders[i], self.target_qf_encoders[i]))
            pairs.append((self.qf_heads[i], self.target_qf_heads[i]))
        return pairs

    @override(RLModule)
    def get_initial_state(self) -> dict:
        """TQC does not support RNNs yet."""
        return {}
```
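For context on what these (network, target network) pairs are for: a minimal Polyak-update sketch. This is illustrative only; in RLlib the actual update is driven by the `TargetNetworkAPI`/Learner machinery, with `tau` coming from the config.

```python
import torch


@torch.no_grad()
def soft_update(pairs, tau: float = 0.005):
    """Polyak-average each online network into its paired target network."""
    for net, target_net in pairs:
        for param, target_param in zip(net.parameters(), target_net.parameters()):
            target_param.mul_(1.0 - tau).add_(tau * param)
```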
New file (empty).
Review comment:

The `DefaultTQCTorchRLModule` duplicates a significant amount of logic from this base class (e.g., in `setup`, `make_target_networks`, `get_target_network_pairs`) instead of inheriting from it. This leads to code duplication and makes maintenance harder.

To improve this, `DefaultTQCTorchRLModule` should inherit from `DefaultTQCRLModule`. This would require some refactoring to make the base class methods more framework-agnostic (e.g., by not initializing lists for models directly, but letting subclasses handle it). For example, `setup()` in the base class could be refactored to call framework-specific model creation methods. This would allow for better code reuse and a cleaner architecture, following the pattern seen in other RLlib algorithms like PPO.

Additionally, the abstract method `_qf_forward_helper` seems unused in the PyTorch implementation and could potentially be removed. The signature of `_qf_forward_all_critics` also differs between this base class and the PyTorch implementation, which would need to be reconciled if inheritance is used.
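To make the suggested direction concrete, here is a rough, self-contained sketch of the inheritance pattern the review describes. The `*Sketch` class names, the `_build_qf_networks` hook, and the `nn.Linear` placeholder layers are assumptions for illustration only, not code from this PR or from RLlib:

```python
from abc import ABC, abstractmethod
from typing import Tuple

import torch.nn as nn


class TQCModuleBaseSketch(ABC):
    """Framework-agnostic base: critics are built via a hook, not directly."""

    def __init__(self, n_critics: int = 2):
        self.n_critics = n_critics
        # The framework-specific subclass picks the container type
        # (plain list, nn.ModuleList, ...), as the review suggests.
        self.qf_encoders, self.qf_heads = self._build_qf_networks()

    @abstractmethod
    def _build_qf_networks(self) -> Tuple[object, object]:
        """Returns (qf_encoders, qf_heads) as framework-native containers."""


class TQCTorchModuleSketch(TQCModuleBaseSketch, nn.Module):
    def __init__(self, n_critics: int = 2):
        nn.Module.__init__(self)
        TQCModuleBaseSketch.__init__(self, n_critics)

    def _build_qf_networks(self):
        # nn.ModuleList registers the critic parameters with this torch module;
        # the layer sizes are placeholders for catalog-built networks.
        encoders = nn.ModuleList(nn.Linear(4, 8) for _ in range(self.n_critics))
        heads = nn.ModuleList(nn.Linear(8, 25) for _ in range(self.n_critics))
        return encoders, heads
```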