-
Notifications
You must be signed in to change notification settings - Fork 7.1k
[RLlib] Add TQC (Truncated Quantile Critics) algorithm implementation #59808
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code Review
This pull request introduces a new algorithm, TQC, to RLlib. The implementation is comprehensive, including the core algorithm logic, a Torch-specific RLModule and Learner, extensive tests, and excellent documentation. The code is well-structured and follows existing patterns in RLlib.
I have a few suggestions for improvement:
- There is some code duplication between the base
DefaultTQCRLModuleand theDefaultTQCTorchRLModule. I recommend refactoring to use inheritance to improve maintainability. - A minor inefficiency in
DefaultTQCRLModule.get_non_inference_attributescan be cleaned up. - An unused variable in
TQCTorchLearnershould be removed.
Overall, this is a great contribution and a solid implementation of the TQC algorithm.
|
|
||
|
|
||
| @DeveloperAPI | ||
| class DefaultTQCRLModule(RLModule, InferenceOnlyAPI, TargetNetworkAPI, QNetAPI): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The DefaultTQCTorchRLModule duplicates a significant amount of logic from this base class (e.g., in setup, make_target_networks, get_target_network_pairs) instead of inheriting from it. This leads to code duplication and makes maintenance harder.
To improve this, DefaultTQCTorchRLModule should inherit from DefaultTQCRLModule. This would require some refactoring to make the base class methods more framework-agnostic (e.g., by not initializing lists for models directly, but letting subclasses handle it).
For example, setup() in the base class could be refactored to call framework-specific model creation methods. This would allow for better code reuse and a cleaner architecture, following the pattern seen in other RLlib algorithms like PPO.
Additionally, the abstract method _qf_forward_helper seems unused in the PyTorch implementation and could potentially be removed. The signature of _qf_forward_all_critics also differs between this base class and the PyTorch implementation, which would need to be reconciled if inheritance is used.
| attrs = [] | ||
| for i in range(self.n_critics): | ||
| attrs.extend([ | ||
| f"qf_encoders", | ||
| f"qf_heads", | ||
| f"target_qf_encoders", | ||
| f"target_qf_heads", | ||
| ]) | ||
| return list(set(attrs)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The loop in this method is inefficient as it adds the same strings multiple times to the attrs list, which are then deduplicated using set(). This can be simplified by directly returning a list of unique attribute names. The corresponding PyTorch implementation DefaultTQCTorchRLModule already does this correctly.
return [
"qf_encoders",
"qf_heads",
"target_qf_encoders",
"target_qf_heads",
]| # Get TQC parameters | ||
| n_quantiles = config.n_quantiles | ||
| n_critics = config.n_critics | ||
| top_quantiles_to_drop = config.top_quantiles_to_drop_per_net |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Adds a new RLlib algorithm TQC, which extends SAC with distributional critics using quantile regression to control Q-function overestimation bias. Key components: - TQC algorithm configuration and implementation - Default TQC RLModule with multiple quantile critics - TQC catalog for building network components - Comprehensive test suite covering compilation, simple environments, and parameter validation - Documentation including Signed-off-by: tk42 <[email protected]>
…loss computation Removes unnecessary abstract methods and code from TQC implementation: - Removes unused `_qf_forward_helper` and `_qf_forward_all_critics` abstract methods from base module - Simplifies `get_non_inference_attributes` to return fixed list instead of dynamically building duplicates - Removes unnecessary dict observation handling (now handled by connectors) - Refactors quantile Huber loss to compute per-sample losses Signed-off-by: tk42 <[email protected]>
Removes unused `gymnasium` import from TQC torch module and learner, and removes unused `SACLearner` import from TQC torch learner. Signed-off-by: tk42 <[email protected]>
Applies code formatting improvements to TQC implementation: - Removes unused imports (`abstractmethod`, `Any`, `Dict`, `Encoder`, `Model`, `OverrideToImplementCustomLogic`) - Fixes line length violations by splitting long lines - Applies consistent formatting for exponentiation operators (`**`) - Removes trailing newline in default_tqc_rl_module.py Signed-off-by: tk42 <[email protected]>
simonsays1980
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Awesome work @tk42 !! Thanks a lot for the contribution. There are only some minor formatting things to be checked - otherwise good to go!
| @property | ||
| @override(AlgorithmConfig) | ||
| def _model_config_auto_includes(self): | ||
| return super()._model_config_auto_includes | { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Very nice!
| Returns: | ||
| Number of quantiles to use for target computation. | ||
| """ | ||
| config = self.config.get_config_for_module(module_id) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Very nice!
| @override(TorchRLModule) | ||
| def setup(self): | ||
| # Get TQC-specific parameters | ||
| self.n_quantiles = self.model_config.get("n_quantiles", 25) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
These attributes should already be assigned in the DefaultTQCRLModule. Can be removed here.
| # Build the policy encoder | ||
| self.pi_encoder = self.catalog.build_encoder(framework=self.framework) | ||
|
|
||
| if not self.inference_only or self.framework != "torch": |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nice!!
| @override(InferenceOnlyAPI) | ||
| def get_non_inference_attributes(self): | ||
| """Returns attributes not needed for inference.""" | ||
| return [ |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Great!! APIs perfectly implemented!
| td_error = torch.abs(qf_preds_mean - target_mean) | ||
|
|
||
| # Log metrics | ||
| self.metrics.log_value( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Very nice!
| def compute_gradients( | ||
| self, loss_per_module: Dict[ModuleID, TensorType], **kwargs | ||
| ) -> ParamDict: | ||
| """Computes gradients for each optimizer separately.""" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can't we simply call super here?
| train_batch_size=10, | ||
| ) | ||
| .env_runners( | ||
| env_to_module_connector=( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Awesome!
| tune.register_env("simple_env_tqc_params", lambda config: SimpleEnv(config)) | ||
|
|
||
| # Test with different n_quantiles and n_critics | ||
| for n_quantiles, n_critics, top_drop in [ |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Awesome!
|
|
||
| def test_tqc_compilation(self): | ||
| """Test whether TQC can be built and trained.""" | ||
| config = ( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can we put the base config in the setUp method?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@tk42 Would be possible to add an example to rllib/examples/algorithms/tqc using Pendulum-v1 or Humanoid?
Adds example scripts for TQC algorithm on Pendulum and Humanoid environments, and refactors TQC module implementation: - Adds `pendulum_tqc.py` example with tuned hyperparameters for Pendulum-v1 - Adds `humanoid_tqc.py` example with tuned hyperparameters for Humanoid-v4 - Refactors TQC torch module to use shared constants (`QF_PREDS`, `QF_TARGET_NEXT`) from SAC learner - Moves TQC-specific parameter initialization to parent `setup()` method Signed-off-by: tk42 <[email protected]>
simonsays1980
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the examples @tk42 . There is one issue with the TQCTorchRLModule which needs to be decided on and then we should be ready to go.
| TQC uses multiple quantile critics, each outputting n_quantiles values. | ||
| """ | ||
|
|
||
| framework: str = "torch" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could we here do one of the following:
- Either inherit directly from
TQCRLModule - Or implement all logic and (like in BC) have only a
TQCTorchRLModule?
| # and (if needed) use their values to set up `config` below. | ||
| args = parser.parse_args() | ||
|
|
||
| config = ( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Awesome! Thanks a lot for these examples!
…base class Refactors the TQC torch module to inherit from the framework-agnostic DefaultTQCRLModule base class: - Changes inheritance from multiple APIs (InferenceOnlyAPI, TargetNetworkAPI, QNetAPI) to DefaultTQCRLModule - Moves network building logic to parent class, keeping only PyTorch-specific ModuleList conversion in setup() - Removes duplicate method implementations now inherited from base class Signed-off-by: tk42 <[email protected]>
… TQC Adds validation to ensure `top_quantiles_to_drop_per_net` is non-negative in TQCConfig: - Raises ValueError if `top_quantiles_to_drop_per_net` < 0 - Validates parameter before checking total quantiles to drop Signed-off-by: tk42 <[email protected]>
|
I believe I’ve addressed all the points raised. Please let me know if I missed anything. |
simonsays1980
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM. Thanks for the contribution of this cool algorithm @tk42 !
…ray-project#59808) Adds a new RLlib algorithm TQC, which extends SAC with distributional critics using quantile regression to control Q-function overestimation bias. Key components: - TQC algorithm configuration and implementation - Default TQC RLModule with multiple quantile critics - TQC catalog for building network components - Comprehensive test suite covering compilation, simple environments, and parameter validation - Documentation including > Thank you for contributing to Ray! 🚀 > Please review the [Ray Contribution Guide](https://docs.ray.io/en/master/ray-contribute/getting-involved.html) before opening a pull request. >⚠️ Remove these instructions before submitting your PR. > 💡 Tip: Mark as draft if you want early feedback, or ready for review when it's complete. ## Description > Briefly describe what this PR accomplishes and why it's needed. ## Related issues > Link related issues: "Fixes ray-project#1234", "Closes ray-project#1234", or "Related to ray-project#1234". ## Additional information > Optional: Add implementation details, API changes, usage examples, screenshots, etc. --------- Signed-off-by: tk42 <[email protected]> Co-authored-by: simonsays1980 <[email protected]> Signed-off-by: jasonwrwang <[email protected]>
Adds a new RLlib algorithm TQC, which extends SAC with distributional critics using quantile regression to control Q-function overestimation bias.
Key components:
Description
Related issues
Additional information