diff --git a/rllib/BUILD.bazel b/rllib/BUILD.bazel index 40b10b762444..663552842311 100644 --- a/rllib/BUILD.bazel +++ b/rllib/BUILD.bazel @@ -2166,6 +2166,18 @@ py_test( deps = [":conftest"], ) +# TQC +py_test( + name = "test_tqc", + size = "large", + srcs = ["algorithms/tqc/tests/test_tqc.py"], + tags = [ + "algorithms_dir", + "team:rllib", + ], + deps = [":conftest"], +) + # Generic testing py_test( name = "algorithms/tests/test_custom_resource", diff --git a/rllib/algorithms/tqc/README.md b/rllib/algorithms/tqc/README.md new file mode 100644 index 000000000000..4b2082e86823 --- /dev/null +++ b/rllib/algorithms/tqc/README.md @@ -0,0 +1,95 @@ +# TQC (Truncated Quantile Critics) + +## Overview + +TQC is an extension of SAC (Soft Actor-Critic) that uses distributional reinforcement learning with quantile regression to control overestimation bias in the Q-function. + +**Paper**: [Controlling Overestimation Bias with Truncated Mixture of Continuous Distributional Quantile Critics](https://arxiv.org/abs/2005.04269) + +## Key Features + +- **Distributional Critics**: Each critic network outputs multiple quantiles instead of a single Q-value +- **Multiple Critics**: Uses `n_critics` independent critic networks (default: 2) +- **Truncated Targets**: Drops the top quantiles when computing target Q-values to reduce overestimation +- **Quantile Huber Loss**: Uses quantile regression with Huber loss for critic training + +## Usage + +```python +from ray.rllib.algorithms.tqc import TQCConfig + +config = ( + TQCConfig() + .environment("Pendulum-v1") + .training( + n_quantiles=25, # Number of quantiles per critic + n_critics=2, # Number of critic networks + top_quantiles_to_drop_per_net=2, # Quantiles to drop for bias control + ) +) + +algo = config.build() +for _ in range(100): + result = algo.train() + print(f"Episode reward mean: {result['env_runners']['episode_reward_mean']}") +``` + +## Configuration + +### TQC-Specific Parameters + +| Parameter | Default | Description | +| ------------------------------- | ------- | ------------------------------------------------------------------ | +| `n_quantiles` | 25 | Number of quantiles for each critic network | +| `n_critics` | 2 | Number of critic networks | +| `top_quantiles_to_drop_per_net` | 2 | Number of top quantiles to drop per network when computing targets | + +### Inherited from SAC + +TQC inherits all SAC parameters including: + +- `actor_lr`, `critic_lr`, `alpha_lr`: Learning rates +- `tau`: Target network update coefficient +- `initial_alpha`: Initial entropy coefficient +- `target_entropy`: Target entropy for automatic alpha tuning + +## Algorithm Details + +### Critic Update + +1. Each critic outputs `n_quantiles` quantile estimates +2. For target computation: + - Collect all quantiles from all critics: `n_critics * n_quantiles` values + - Sort all quantiles + - Drop the top `top_quantiles_to_drop_per_net * n_critics` quantiles + - Use remaining quantiles as targets +3. 
Train critics using quantile Huber loss + +### Actor Update + +- Maximize expected Q-value (mean of all quantiles) minus entropy bonus +- Same as SAC but using mean of quantile estimates + +### Entropy Tuning + +- Same as SAC: automatically adjusts temperature parameter α + +## Differences from SAC + +| Aspect | SAC | TQC | +| ----------------- | -------------- | ----------------------------- | +| Critic Output | Single Q-value | `n_quantiles` quantile values | +| Number of Critics | 2 (twin_q) | `n_critics` (configurable) | +| Loss Function | Huber/MSE | Quantile Huber Loss | +| Target Q | min(Q1, Q2) | Truncated sorted quantiles | + +## References + +```bibtex +@article{kuznetsov2020controlling, + title={Controlling Overestimation Bias with Truncated Mixture of Continuous Distributional Quantile Critics}, + author={Kuznetsov, Arsenii and Shvechikov, Pavel and Grishin, Alexander and Vetrov, Dmitry}, + journal={arXiv preprint arXiv:2005.04269}, + year={2020} +} +``` diff --git a/rllib/algorithms/tqc/__init__.py b/rllib/algorithms/tqc/__init__.py new file mode 100644 index 000000000000..c0368262c39d --- /dev/null +++ b/rllib/algorithms/tqc/__init__.py @@ -0,0 +1,13 @@ +"""TQC (Truncated Quantile Critics) Algorithm. + +Paper: https://arxiv.org/abs/2005.04269 +""" + +from ray.rllib.algorithms.tqc.tqc import TQC, TQCConfig +from ray.rllib.algorithms.tqc.tqc_catalog import TQCCatalog + +__all__ = [ + "TQC", + "TQCConfig", + "TQCCatalog", +] diff --git a/rllib/algorithms/tqc/default_tqc_rl_module.py b/rllib/algorithms/tqc/default_tqc_rl_module.py new file mode 100644 index 000000000000..610e9e6bee47 --- /dev/null +++ b/rllib/algorithms/tqc/default_tqc_rl_module.py @@ -0,0 +1,100 @@ +""" +Default TQC RLModule. + +TQC uses distributional critics with quantile regression. +""" + +from typing import List, Tuple + +from ray.rllib.core.learner.utils import make_target_network +from ray.rllib.core.rl_module.apis import InferenceOnlyAPI, QNetAPI, TargetNetworkAPI +from ray.rllib.core.rl_module.rl_module import RLModule +from ray.rllib.utils.annotations import ( + override, +) +from ray.rllib.utils.typing import NetworkType +from ray.util.annotations import DeveloperAPI + + +@DeveloperAPI +class DefaultTQCRLModule(RLModule, InferenceOnlyAPI, TargetNetworkAPI, QNetAPI): + """RLModule for the TQC (Truncated Quantile Critics) algorithm. + + TQC extends SAC by using distributional critics with quantile regression. + Each critic outputs n_quantiles values instead of a single Q-value. 
+ + Architecture: + - Policy (Actor): Same as SAC + [obs] -> [pi_encoder] -> [pi_head] -> [action_dist_inputs] + + - Quantile Critics: Multiple critics, each outputting n_quantiles + [obs, action] -> [qf_encoder_i] -> [qf_head_i] -> [n_quantiles values] + + - Target Quantile Critics: Target networks for each critic + [obs, action] -> [target_qf_encoder_i] -> [target_qf_head_i] -> [n_quantiles] + """ + + @override(RLModule) + def setup(self): + # TQC-specific parameters from model_config + self.n_quantiles = self.model_config.get("n_quantiles", 25) + self.n_critics = self.model_config.get("n_critics", 2) + self.top_quantiles_to_drop_per_net = self.model_config.get( + "top_quantiles_to_drop_per_net", 2 + ) + + # Total quantiles across all critics + self.quantiles_total = self.n_quantiles * self.n_critics + + # Build the encoder for the policy (same as SAC) + self.pi_encoder = self.catalog.build_encoder(framework=self.framework) + + if not self.inference_only or self.framework != "torch": + # Build multiple Q-function encoders and heads + self.qf_encoders = [] + self.qf_heads = [] + + for i in range(self.n_critics): + qf_encoder = self.catalog.build_qf_encoder(framework=self.framework) + qf_head = self.catalog.build_qf_head(framework=self.framework) + self.qf_encoders.append(qf_encoder) + self.qf_heads.append(qf_head) + + # Build the policy head (same as SAC) + self.pi = self.catalog.build_pi_head(framework=self.framework) + + @override(TargetNetworkAPI) + def make_target_networks(self): + """Creates target networks for all quantile critics.""" + self.target_qf_encoders = [] + self.target_qf_heads = [] + + for i in range(self.n_critics): + target_encoder = make_target_network(self.qf_encoders[i]) + target_head = make_target_network(self.qf_heads[i]) + self.target_qf_encoders.append(target_encoder) + self.target_qf_heads.append(target_head) + + @override(InferenceOnlyAPI) + def get_non_inference_attributes(self) -> List[str]: + """Returns attributes not needed for inference.""" + return [ + "qf_encoders", + "qf_heads", + "target_qf_encoders", + "target_qf_heads", + ] + + @override(TargetNetworkAPI) + def get_target_network_pairs(self) -> List[Tuple[NetworkType, NetworkType]]: + """Returns pairs of (network, target_network) for updating targets.""" + pairs = [] + for i in range(self.n_critics): + pairs.append((self.qf_encoders[i], self.target_qf_encoders[i])) + pairs.append((self.qf_heads[i], self.target_qf_heads[i])) + return pairs + + @override(RLModule) + def get_initial_state(self) -> dict: + """TQC does not support RNNs yet.""" + return {} diff --git a/rllib/algorithms/tqc/tests/__init__.py b/rllib/algorithms/tqc/tests/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/rllib/algorithms/tqc/tests/test_tqc.py b/rllib/algorithms/tqc/tests/test_tqc.py new file mode 100644 index 000000000000..b478fe01b32d --- /dev/null +++ b/rllib/algorithms/tqc/tests/test_tqc.py @@ -0,0 +1,231 @@ +"""Tests for the TQC (Truncated Quantile Critics) algorithm.""" + +import unittest + +import gymnasium as gym +import numpy as np +from gymnasium.spaces import Box, Dict, Discrete, Tuple + +import ray +from ray import tune +from ray.rllib.algorithms import tqc +from ray.rllib.connectors.env_to_module.flatten_observations import FlattenObservations +from ray.rllib.examples.envs.classes.random_env import RandomEnv +from ray.rllib.utils.framework import try_import_torch +from ray.rllib.utils.test_utils import check_train_results_new_api_stack + +torch, _ = try_import_torch() + + +class 
SimpleEnv(gym.Env): + """Simple continuous control environment for testing.""" + + def __init__(self, config): + self.action_space = Box(0.0, 1.0, (1,)) + self.observation_space = Box(0.0, 1.0, (1,)) + self.max_steps = config.get("max_steps", 100) + self.state = None + self.steps = None + + def reset(self, *, seed=None, options=None): + self.state = self.observation_space.sample() + self.steps = 0 + return self.state, {} + + def step(self, action): + self.steps += 1 + # Reward is 1.0 - (max(actions) - state). + [rew] = 1.0 - np.abs(np.max(action) - self.state) + terminated = False + truncated = self.steps >= self.max_steps + self.state = self.observation_space.sample() + return self.state, rew, terminated, truncated, {} + + +class TestTQC(unittest.TestCase): + """Test cases for TQC algorithm.""" + + @classmethod + def setUpClass(cls) -> None: + np.random.seed(42) + torch.manual_seed(42) + ray.init() + + @classmethod + def tearDownClass(cls) -> None: + ray.shutdown() + + def setUp(self) -> None: + """Set up base config for tests.""" + self.base_config = ( + tqc.TQCConfig() + .training( + n_step=3, + n_quantiles=25, + n_critics=2, + top_quantiles_to_drop_per_net=2, + replay_buffer_config={ + "capacity": 40000, + }, + num_steps_sampled_before_learning_starts=0, + store_buffer_in_checkpoints=True, + train_batch_size=10, + ) + .env_runners( + num_env_runners=0, + rollout_fragment_length=10, + ) + ) + + def test_tqc_compilation(self): + """Test whether TQC can be built and trained.""" + config = self.base_config.copy().env_runners( + env_to_module_connector=(lambda env, spaces, device: FlattenObservations()), + ) + num_iterations = 1 + + image_space = Box(-1.0, 1.0, shape=(84, 84, 3)) + simple_space = Box(-1.0, 1.0, shape=(3,)) + + tune.register_env( + "random_dict_env_tqc", + lambda _: RandomEnv( + { + "observation_space": Dict( + { + "a": simple_space, + "b": Discrete(2), + "c": image_space, + } + ), + "action_space": Box(-1.0, 1.0, shape=(1,)), + } + ), + ) + tune.register_env( + "random_tuple_env_tqc", + lambda _: RandomEnv( + { + "observation_space": Tuple( + [simple_space, Discrete(2), image_space] + ), + "action_space": Box(-1.0, 1.0, shape=(1,)), + } + ), + ) + + # Test for different env types (dict and tuple observations). 
+ for env in [ + "random_dict_env_tqc", + "random_tuple_env_tqc", + ]: + print("Env={}".format(env)) + config.environment(env) + algo = config.build() + for i in range(num_iterations): + results = algo.train() + check_train_results_new_api_stack(results) + print(results) + + algo.stop() + + def test_tqc_simple_env(self): + """Test TQC on a simple continuous control environment.""" + tune.register_env("simple_env_tqc", lambda config: SimpleEnv(config)) + + config = ( + tqc.TQCConfig() + .environment("simple_env_tqc", env_config={"max_steps": 50}) + .training( + n_quantiles=10, + n_critics=2, + top_quantiles_to_drop_per_net=1, + replay_buffer_config={ + "capacity": 10000, + }, + num_steps_sampled_before_learning_starts=0, + train_batch_size=32, + ) + .env_runners( + num_env_runners=0, + rollout_fragment_length=10, + ) + ) + + algo = config.build() + for _ in range(2): + results = algo.train() + check_train_results_new_api_stack(results) + print(results) + + algo.stop() + + def test_tqc_quantile_parameters(self): + """Test TQC with different quantile configurations.""" + tune.register_env("simple_env_tqc_params", lambda config: SimpleEnv(config)) + + # Test with different n_quantiles and n_critics + for n_quantiles, n_critics, top_drop in [ + (5, 2, 1), + (25, 3, 2), + (50, 2, 5), + ]: + print( + f"Testing n_quantiles={n_quantiles}, n_critics={n_critics}, " + f"top_drop={top_drop}" + ) + + config = ( + tqc.TQCConfig() + .environment("simple_env_tqc_params", env_config={"max_steps": 20}) + .training( + n_quantiles=n_quantiles, + n_critics=n_critics, + top_quantiles_to_drop_per_net=top_drop, + replay_buffer_config={ + "capacity": 5000, + }, + num_steps_sampled_before_learning_starts=0, + train_batch_size=16, + ) + .env_runners( + num_env_runners=0, + rollout_fragment_length=5, + ) + ) + + algo = config.build() + results = algo.train() + check_train_results_new_api_stack(results) + algo.stop() + + def test_tqc_config_validation(self): + """Test that TQC config validation works correctly.""" + # Test invalid n_quantiles + with self.assertRaises(ValueError): + config = tqc.TQCConfig().training(n_quantiles=0) + config.validate() + + # Test invalid n_critics + with self.assertRaises(ValueError): + config = tqc.TQCConfig().training(n_critics=0) + config.validate() + + # Test dropping too many quantiles + with self.assertRaises(ValueError): + # With n_quantiles=5, n_critics=2, total=10 + # Dropping 6 per net = 12 total, which is > 10 + config = tqc.TQCConfig().training( + n_quantiles=5, + n_critics=2, + top_quantiles_to_drop_per_net=6, + ) + config.validate() + + +if __name__ == "__main__": + import sys + + import pytest + + sys.exit(pytest.main(["-v", __file__])) diff --git a/rllib/algorithms/tqc/torch/__init__.py b/rllib/algorithms/tqc/torch/__init__.py new file mode 100644 index 000000000000..380acbe89e8d --- /dev/null +++ b/rllib/algorithms/tqc/torch/__init__.py @@ -0,0 +1,11 @@ +"""TQC PyTorch implementations.""" + +from ray.rllib.algorithms.tqc.torch.default_tqc_torch_rl_module import ( + DefaultTQCTorchRLModule, +) +from ray.rllib.algorithms.tqc.torch.tqc_torch_learner import TQCTorchLearner + +__all__ = [ + "DefaultTQCTorchRLModule", + "TQCTorchLearner", +] diff --git a/rllib/algorithms/tqc/torch/default_tqc_torch_rl_module.py b/rllib/algorithms/tqc/torch/default_tqc_torch_rl_module.py new file mode 100644 index 000000000000..52d3fa319340 --- /dev/null +++ b/rllib/algorithms/tqc/torch/default_tqc_torch_rl_module.py @@ -0,0 +1,235 @@ +""" +PyTorch implementation of the TQC RLModule. 
+""" + +from typing import Any, Dict + +from ray.rllib.algorithms.sac.sac_learner import QF_PREDS, QF_TARGET_NEXT +from ray.rllib.algorithms.tqc.default_tqc_rl_module import DefaultTQCRLModule +from ray.rllib.algorithms.tqc.tqc_catalog import TQCCatalog +from ray.rllib.core.columns import Columns +from ray.rllib.core.learner.utils import make_target_network +from ray.rllib.core.models.base import ENCODER_OUT +from ray.rllib.core.rl_module.torch.torch_rl_module import TorchRLModule +from ray.rllib.utils.annotations import override +from ray.rllib.utils.framework import try_import_torch + +torch, nn = try_import_torch() + + +class DefaultTQCTorchRLModule(TorchRLModule, DefaultTQCRLModule): + """PyTorch implementation of the TQC RLModule. + + TQC uses multiple quantile critics, each outputting n_quantiles values. + """ + + framework: str = "torch" + + def __init__(self, *args, **kwargs): + catalog_class = kwargs.pop("catalog_class", None) + if catalog_class is None: + catalog_class = TQCCatalog + super().__init__(*args, **kwargs, catalog_class=catalog_class) + + @override(DefaultTQCRLModule) + def setup(self): + # Call parent setup to initialize TQC-specific parameters and build networks + super().setup() + + # Convert lists to nn.ModuleList for proper PyTorch parameter tracking + if not self.inference_only or self.framework != "torch": + self.qf_encoders = nn.ModuleList(self.qf_encoders) + self.qf_heads = nn.ModuleList(self.qf_heads) + + @override(DefaultTQCRLModule) + def make_target_networks(self): + """Creates target networks for all quantile critics.""" + self.target_qf_encoders = nn.ModuleList() + self.target_qf_heads = nn.ModuleList() + + for i in range(self.n_critics): + target_encoder = make_target_network(self.qf_encoders[i]) + target_head = make_target_network(self.qf_heads[i]) + self.target_qf_encoders.append(target_encoder) + self.target_qf_heads.append(target_head) + + @override(TorchRLModule) + def _forward_inference(self, batch: Dict[str, Any]) -> Dict[str, Any]: + """Forward pass for inference (action selection). + + Same as SAC - samples actions from the policy. + """ + output = {} + + # Extract features from observations + pi_encoder_out = self.pi_encoder(batch) + pi_out = self.pi(pi_encoder_out[ENCODER_OUT]) + + output[Columns.ACTION_DIST_INPUTS] = pi_out + + return output + + @override(TorchRLModule) + def _forward_exploration(self, batch: Dict[str, Any], **kwargs) -> Dict[str, Any]: + """Forward pass for exploration. + + Same as inference for TQC (stochastic policy). + """ + return self._forward_inference(batch) + + @override(TorchRLModule) + def _forward_train(self, batch: Dict[str, Any]) -> Dict[str, Any]: + """Forward pass for training. 
+ + Computes: + - Action distribution inputs from current observations + - Q-values (quantiles) for current state-action pairs + - Q-values (quantiles) for next states with resampled actions + """ + output = {} + + # Get action distribution inputs for current observations + pi_encoder_out = self.pi_encoder(batch) + pi_out = self.pi(pi_encoder_out[ENCODER_OUT]) + output[Columns.ACTION_DIST_INPUTS] = pi_out + + # Sample actions from current policy for current observations + action_dist_class = self.catalog.get_action_dist_cls(framework=self.framework) + action_dist_curr = action_dist_class.from_logits(pi_out) + actions_curr = action_dist_curr.rsample() + logp_curr = action_dist_curr.logp(actions_curr) + + output["actions_curr"] = actions_curr + output["logp_curr"] = logp_curr + + # Compute Q-values for actions from replay buffer + qf_out = self._qf_forward_all_critics( + batch[Columns.OBS], + batch[Columns.ACTIONS], + use_target=False, + ) + output[QF_PREDS] = qf_out # (batch, n_critics, n_quantiles) + + # Compute Q-values for resampled actions (for actor loss) + qf_curr = self._qf_forward_all_critics( + batch[Columns.OBS], + actions_curr, + use_target=False, + ) + output["qf_curr"] = qf_curr + + # For next state Q-values (target computation) + if Columns.NEXT_OBS in batch: + # Get action distribution for next observations + pi_encoder_out_next = self.pi_encoder( + {Columns.OBS: batch[Columns.NEXT_OBS]} + ) + pi_out_next = self.pi(pi_encoder_out_next[ENCODER_OUT]) + + # Sample actions for next state + action_dist_next = action_dist_class.from_logits(pi_out_next) + actions_next = action_dist_next.rsample() + logp_next = action_dist_next.logp(actions_next) + + output["actions_next"] = actions_next + output["logp_next"] = logp_next + + # Compute target Q-values for next state + qf_target_next = self._qf_forward_all_critics( + batch[Columns.NEXT_OBS], + actions_next, + use_target=True, + ) + output[QF_TARGET_NEXT] = qf_target_next + + return output + + def _qf_forward_all_critics( + self, + obs: torch.Tensor, + actions: torch.Tensor, + use_target: bool = False, + ) -> torch.Tensor: + """Forward pass through all critic networks. + + Args: + obs: Observations tensor. + actions: Actions tensor. + use_target: Whether to use target networks. + + Returns: + Stacked quantile values from all critics. + Shape: (batch_size, n_critics, n_quantiles) + """ + # Note: obs should already be a flat tensor at this point. + # Dict observations are handled by connectors (e.g., FlattenObservations) + # before reaching this method. + + # Concatenate observations and actions + qf_input = torch.cat([obs, actions], dim=-1) + batch_dict = {Columns.OBS: qf_input} + + encoders = self.target_qf_encoders if use_target else self.qf_encoders + heads = self.target_qf_heads if use_target else self.qf_heads + + quantiles_list = [] + for encoder, head in zip(encoders, heads): + encoder_out = encoder(batch_dict) + quantiles = head(encoder_out[ENCODER_OUT]) # (batch, n_quantiles) + quantiles_list.append(quantiles) + + # Stack: (batch, n_critics, n_quantiles) + return torch.stack(quantiles_list, dim=1) + + @override(DefaultTQCRLModule) + def compute_q_values(self, batch: Dict[str, Any]) -> Dict[str, Any]: + """Computes Q-values (mean of quantiles) for the given batch. + + Args: + batch: Dict containing observations and actions. + + Returns: + Mean Q-value across all quantiles and critics. 
+ """ + obs = batch[Columns.OBS] + actions = batch[Columns.ACTIONS] + + # Get all quantiles from all critics + quantiles = self._qf_forward_all_critics(obs, actions, use_target=False) + + # Return mean across all quantiles and critics + return quantiles.mean(dim=(1, 2)) + + @override(DefaultTQCRLModule) + def forward_target(self, batch: Dict[str, Any]) -> Dict[str, Any]: + """Forward pass through target networks. + + Args: + batch: Dict containing observations and actions. + + Returns: + Target Q-values (mean of truncated quantiles). + """ + obs = batch[Columns.OBS] + actions = batch[Columns.ACTIONS] + + # Get all quantiles from target critics + quantiles = self._qf_forward_all_critics(obs, actions, use_target=True) + + # Flatten, sort, and truncate top quantiles + batch_size = quantiles.shape[0] + quantiles_flat = quantiles.reshape(batch_size, -1) + quantiles_sorted, _ = torch.sort(quantiles_flat, dim=1) + + # Calculate number of quantiles to keep + n_target_quantiles = ( + self.quantiles_total - self.top_quantiles_to_drop_per_net * self.n_critics + ) + quantiles_truncated = quantiles_sorted[:, :n_target_quantiles] + + # Return mean of truncated quantiles + return quantiles_truncated.mean(dim=1) + + @staticmethod + def _get_catalog_class(): + return TQCCatalog diff --git a/rllib/algorithms/tqc/torch/tqc_torch_learner.py b/rllib/algorithms/tqc/torch/tqc_torch_learner.py new file mode 100644 index 000000000000..f74818e47573 --- /dev/null +++ b/rllib/algorithms/tqc/torch/tqc_torch_learner.py @@ -0,0 +1,286 @@ +""" +PyTorch implementation of the TQC Learner. + +Implements the TQC loss computation with quantile Huber loss. +""" + +from typing import Any, Dict + +from ray.rllib.algorithms.algorithm_config import AlgorithmConfig +from ray.rllib.algorithms.sac.sac_learner import ( + LOGPS_KEY, + QF_PREDS, + QF_TARGET_NEXT, +) +from ray.rllib.algorithms.sac.torch.sac_torch_learner import SACTorchLearner +from ray.rllib.algorithms.tqc.tqc import TQCConfig +from ray.rllib.algorithms.tqc.tqc_learner import ( + QF_LOSS_KEY, + QF_MAX_KEY, + QF_MEAN_KEY, + QF_MIN_KEY, + TD_ERROR_MEAN_KEY, + TQCLearner, +) +from ray.rllib.core.columns import Columns +from ray.rllib.core.learner.learner import POLICY_LOSS_KEY +from ray.rllib.utils.annotations import override +from ray.rllib.utils.framework import try_import_torch +from ray.rllib.utils.metrics import TD_ERROR_KEY +from ray.rllib.utils.typing import ModuleID, TensorType + +torch, nn = try_import_torch() + + +def quantile_huber_loss_per_sample( + quantiles: torch.Tensor, + target_quantiles: torch.Tensor, + kappa: float = 1.0, +) -> torch.Tensor: + """Computes the quantile Huber loss per sample (for importance sampling). + + Args: + quantiles: Current quantile estimates. Shape: (batch, n_quantiles) + target_quantiles: Target quantile values. Shape: (batch, n_target_quantiles) + kappa: Huber loss threshold parameter. + + Returns: + Per-sample quantile Huber loss. 
Shape: (batch,) + """ + n_quantiles = quantiles.shape[1] + + # Compute cumulative probabilities for quantiles (tau values) + tau = ( + torch.arange(n_quantiles, device=quantiles.device, dtype=quantiles.dtype) + 0.5 + ) / n_quantiles + + # Expand dimensions for broadcasting + quantiles_expanded = quantiles.unsqueeze(2) + target_expanded = target_quantiles.unsqueeze(1) + + # Compute pairwise TD errors: (batch, n_quantiles, n_target_quantiles) + td_error = target_expanded - quantiles_expanded + + # Compute Huber loss element-wise using nn.HuberLoss + huber_loss_fn = nn.HuberLoss(reduction="none", delta=kappa) + huber_loss = huber_loss_fn(quantiles_expanded, target_expanded) + + # Compute quantile weights + tau_expanded = tau.view(1, n_quantiles, 1) + quantile_weight = torch.abs(tau_expanded - (td_error < 0).float()) + + # Weighted Huber loss + quantile_huber = quantile_weight * huber_loss + + # Sum over quantile dimensions, keep batch dimension + return quantile_huber.sum(dim=(1, 2)) + + +class TQCTorchLearner(SACTorchLearner, TQCLearner): + """PyTorch Learner for TQC algorithm. + + Implements the TQC loss computation: + - Critic loss: Quantile Huber loss with truncated targets + - Actor loss: Maximize mean Q-value (from truncated quantiles) + - Alpha loss: Same as SAC (entropy regularization) + """ + + @override(SACTorchLearner) + def build(self) -> None: + super().build() + self._temp_losses = {} + + @override(SACTorchLearner) + def configure_optimizers_for_module( + self, module_id: ModuleID, config: AlgorithmConfig = None + ) -> None: + """Configures optimizers for TQC. + + TQC has separate optimizers for: + - All critic networks (shared optimizer) + - Actor network + - Temperature (alpha) parameter + """ + module = self._module[module_id] + + # Collect all critic parameters + critic_params = [] + for encoder in module.qf_encoders: + critic_params.extend(self.get_parameters(encoder)) + for head in module.qf_heads: + critic_params.extend(self.get_parameters(head)) + + optim_critic = torch.optim.Adam(critic_params, eps=1e-7) + self.register_optimizer( + module_id=module_id, + optimizer_name="qf", + optimizer=optim_critic, + params=critic_params, + lr_or_lr_schedule=config.critic_lr, + ) + + # Actor optimizer + params_actor = self.get_parameters(module.pi_encoder) + self.get_parameters( + module.pi + ) + optim_actor = torch.optim.Adam(params_actor, eps=1e-7) + self.register_optimizer( + module_id=module_id, + optimizer_name="policy", + optimizer=optim_actor, + params=params_actor, + lr_or_lr_schedule=config.actor_lr, + ) + + # Temperature optimizer + temperature = self.curr_log_alpha[module_id] + optim_temperature = torch.optim.Adam([temperature], eps=1e-7) + self.register_optimizer( + module_id=module_id, + optimizer_name="alpha", + optimizer=optim_temperature, + params=[temperature], + lr_or_lr_schedule=config.alpha_lr, + ) + + @override(SACTorchLearner) + def compute_loss_for_module( + self, + *, + module_id: ModuleID, + config: TQCConfig, + batch: Dict[str, Any], + fwd_out: Dict[str, TensorType], + ) -> TensorType: + """Computes the TQC loss. + + Args: + module_id: The module ID. + config: The TQC configuration. + batch: The training batch. + fwd_out: Forward pass outputs. + + Returns: + Total loss (sum of critic, actor, and alpha losses). 
+ """ + # Get current alpha (temperature parameter) + alpha = torch.exp(self.curr_log_alpha[module_id]) + + # Get TQC parameters + n_critics = config.n_critics + n_target_quantiles = self._get_n_target_quantiles(module_id) + + batch_size = batch[Columns.OBS].shape[0] + + # === Critic Loss === + # Get current Q-value predictions (quantiles) + # Shape: (batch, n_critics, n_quantiles) + qf_preds = fwd_out[QF_PREDS] + + # Get target Q-values for next state + # Shape: (batch, n_critics, n_quantiles) + qf_target_next = fwd_out[QF_TARGET_NEXT] + logp_next = fwd_out["logp_next"] + + # Flatten and sort quantiles across all critics + # Shape: (batch, n_critics * n_quantiles) + qf_target_next_flat = qf_target_next.reshape(batch_size, -1) + + # Sort and truncate top quantiles to control overestimation + qf_target_next_sorted, _ = torch.sort(qf_target_next_flat, dim=1) + qf_target_next_truncated = qf_target_next_sorted[:, :n_target_quantiles] + + # Compute target with entropy bonus + # Shape: (batch, n_target_quantiles) + target_quantiles = ( + qf_target_next_truncated - alpha.detach() * logp_next.unsqueeze(1) + ) + + # Compute TD targets + rewards = batch[Columns.REWARDS].unsqueeze(1) + terminateds = batch[Columns.TERMINATEDS].float().unsqueeze(1) + gamma = config.gamma + n_step = batch.get("n_step", torch.ones_like(batch[Columns.REWARDS])) + if isinstance(n_step, (int, float)): + n_step = torch.full_like(batch[Columns.REWARDS], n_step) + + target_quantiles = ( + rewards + + (1.0 - terminateds) * (gamma ** n_step.unsqueeze(1)) * target_quantiles + ).detach() + + # Get importance sampling weights for prioritized replay + weights = batch.get("weights", torch.ones_like(batch[Columns.REWARDS])) + + # Compute critic loss for each critic + critic_loss = torch.tensor(0.0, device=qf_preds.device) + for i in range(n_critics): + # Get quantiles for this critic: (batch, n_quantiles) + critic_quantiles = qf_preds[:, i, :] + # Compute per-sample quantile huber loss + critic_loss_per_sample = quantile_huber_loss_per_sample( + critic_quantiles, + target_quantiles, + ) + # Apply importance sampling weights + critic_loss += torch.mean(weights * critic_loss_per_sample) + + # === Actor Loss === + # Get Q-values for resampled actions + qf_curr = fwd_out["qf_curr"] # (batch, n_critics, n_quantiles) + logp_curr = fwd_out["logp_curr"] + + # Mean over all quantiles and critics + qf_curr_mean = qf_curr.mean(dim=(1, 2)) + + # Actor loss: maximize Q-value while maintaining entropy + actor_loss = (alpha.detach() * logp_curr - qf_curr_mean).mean() + + # === Alpha Loss === + alpha_loss = -torch.mean( + self.curr_log_alpha[module_id] + * (logp_curr.detach() + self.target_entropy[module_id]) + ) + + # Total loss + total_loss = critic_loss + actor_loss + alpha_loss + + # Compute TD error for prioritized replay + # Use mean across critics and quantiles + qf_preds_mean = qf_preds.mean(dim=(1, 2)) + target_mean = target_quantiles.mean(dim=1) + td_error = torch.abs(qf_preds_mean - target_mean) + + # Log metrics + self.metrics.log_value( + key=(module_id, TD_ERROR_KEY), + value=td_error, + reduce="item_series", + ) + + self.metrics.log_dict( + { + POLICY_LOSS_KEY: actor_loss, + QF_LOSS_KEY: critic_loss, + "alpha_loss": alpha_loss, + "alpha_value": alpha[0], + "log_alpha_value": torch.log(alpha)[0], + "target_entropy": self.target_entropy[module_id], + LOGPS_KEY: torch.mean(logp_curr), + QF_MEAN_KEY: torch.mean(qf_preds), + QF_MAX_KEY: torch.max(qf_preds), + QF_MIN_KEY: torch.min(qf_preds), + TD_ERROR_MEAN_KEY: torch.mean(td_error), + }, + 
key=module_id, + window=1, + ) + + # Store losses for gradient computation + self._temp_losses[(module_id, POLICY_LOSS_KEY)] = actor_loss + self._temp_losses[(module_id, QF_LOSS_KEY)] = critic_loss + self._temp_losses[(module_id, "alpha_loss")] = alpha_loss + + return total_loss + + # Note: compute_gradients is inherited from SACTorchLearner diff --git a/rllib/algorithms/tqc/tqc.py b/rllib/algorithms/tqc/tqc.py new file mode 100644 index 000000000000..6a0c2eafdeb2 --- /dev/null +++ b/rllib/algorithms/tqc/tqc.py @@ -0,0 +1,171 @@ +""" +TQC (Truncated Quantile Critics) Algorithm. + +Paper: https://arxiv.org/abs/2005.04269 +"Controlling Overestimation Bias with Truncated Mixture of Continuous +Distributional Quantile Critics" + +TQC extends SAC by using distributional RL with quantile regression to +control overestimation bias in the Q-function. +""" + +import logging +from typing import Optional, Type, Union + +from ray.rllib.algorithms.algorithm import Algorithm +from ray.rllib.algorithms.algorithm_config import AlgorithmConfig, NotProvided +from ray.rllib.algorithms.sac.sac import SAC, SACConfig +from ray.rllib.core.learner import Learner +from ray.rllib.core.rl_module.rl_module import RLModuleSpec +from ray.rllib.utils.annotations import override +from ray.rllib.utils.typing import RLModuleSpecType + +logger = logging.getLogger(__name__) + + +class TQCConfig(SACConfig): + """Configuration for the TQC algorithm. + + TQC extends SAC with distributional critics using quantile regression. + + Example: + >>> from ray.rllib.algorithms.tqc import TQCConfig + >>> config = ( + ... TQCConfig() + ... .environment("Pendulum-v1") + ... .training( + ... n_quantiles=25, + ... n_critics=2, + ... top_quantiles_to_drop_per_net=2, + ... ) + ... ) + >>> algo = config.build() + """ + + def __init__(self, algo_class=None): + """Initializes a TQCConfig instance.""" + super().__init__(algo_class=algo_class or TQC) + + # TQC-specific parameters + self.n_quantiles = 25 + self.n_critics = 2 + self.top_quantiles_to_drop_per_net = 2 + + @override(SACConfig) + def training( + self, + *, + n_quantiles: Optional[int] = NotProvided, + n_critics: Optional[int] = NotProvided, + top_quantiles_to_drop_per_net: Optional[int] = NotProvided, + **kwargs, + ): + """Sets the training-related configuration. + + Args: + n_quantiles: Number of quantiles for each critic network. + Default is 25. + n_critics: Number of critic networks. Default is 2. + top_quantiles_to_drop_per_net: Number of quantiles to drop per + network when computing the target Q-value. This controls + the overestimation bias. Default is 2. + **kwargs: Additional arguments passed to SACConfig.training(). + + Returns: + This updated TQCConfig object. 
+ """ + super().training(**kwargs) + + if n_quantiles is not NotProvided: + self.n_quantiles = n_quantiles + if n_critics is not NotProvided: + self.n_critics = n_critics + if top_quantiles_to_drop_per_net is not NotProvided: + self.top_quantiles_to_drop_per_net = top_quantiles_to_drop_per_net + + return self + + @override(AlgorithmConfig) + def validate(self) -> None: + """Validates the TQC configuration.""" + super().validate() + + # Validate TQC-specific parameters + if self.n_quantiles < 1: + raise ValueError(f"`n_quantiles` must be >= 1, got {self.n_quantiles}") + if self.n_critics < 1: + raise ValueError(f"`n_critics` must be >= 1, got {self.n_critics}") + + # Ensure top_quantiles_to_drop_per_net is non-negative + if self.top_quantiles_to_drop_per_net < 0: + raise ValueError( + f"`top_quantiles_to_drop_per_net` must be >= 0, got " + f"{self.top_quantiles_to_drop_per_net}" + ) + + # Ensure we don't drop more quantiles than we have + total_quantiles = self.n_quantiles * self.n_critics + quantiles_to_drop = self.top_quantiles_to_drop_per_net * self.n_critics + if quantiles_to_drop >= total_quantiles: + raise ValueError( + f"Cannot drop {quantiles_to_drop} quantiles when only " + f"{total_quantiles} total quantiles are available. " + f"Reduce `top_quantiles_to_drop_per_net` or increase " + f"`n_quantiles` or `n_critics`." + ) + + @override(AlgorithmConfig) + def get_default_rl_module_spec(self) -> RLModuleSpecType: + if self.framework_str == "torch": + from ray.rllib.algorithms.tqc.torch.default_tqc_torch_rl_module import ( + DefaultTQCTorchRLModule, + ) + + return RLModuleSpec(module_class=DefaultTQCTorchRLModule) + else: + raise ValueError( + f"The framework {self.framework_str} is not supported. Use `torch`." + ) + + @override(AlgorithmConfig) + def get_default_learner_class(self) -> Union[Type["Learner"], str]: + if self.framework_str == "torch": + from ray.rllib.algorithms.tqc.torch.tqc_torch_learner import ( + TQCTorchLearner, + ) + + return TQCTorchLearner + else: + raise ValueError( + f"The framework {self.framework_str} is not supported. Use `torch`." + ) + + @property + @override(AlgorithmConfig) + def _model_config_auto_includes(self): + return super()._model_config_auto_includes | { + "n_quantiles": self.n_quantiles, + "n_critics": self.n_critics, + "top_quantiles_to_drop_per_net": self.top_quantiles_to_drop_per_net, + } + + +class TQC(SAC): + """TQC (Truncated Quantile Critics) Algorithm. + + TQC extends SAC by using distributional critics with quantile regression + and truncating the top quantiles to control overestimation bias. + + Key differences from SAC: + - Uses multiple critic networks, each outputting multiple quantiles + - Computes target Q-values by sorting and truncating top quantiles + - Uses quantile Huber loss for critic training + + See the paper for more details: + https://arxiv.org/abs/2005.04269 + """ + + @classmethod + @override(Algorithm) + def get_default_config(cls) -> TQCConfig: + return TQCConfig() diff --git a/rllib/algorithms/tqc/tqc_catalog.py b/rllib/algorithms/tqc/tqc_catalog.py new file mode 100644 index 000000000000..61e89f252df3 --- /dev/null +++ b/rllib/algorithms/tqc/tqc_catalog.py @@ -0,0 +1,89 @@ +""" +TQC Catalog for building TQC-specific models. + +TQC uses multiple quantile critics, each outputting n_quantiles values. 
+""" + +import gymnasium as gym + +from ray.rllib.algorithms.sac.sac_catalog import SACCatalog +from ray.rllib.core.models.base import Encoder, Model +from ray.rllib.core.models.configs import MLPHeadConfig +from ray.rllib.utils.annotations import OverrideToImplementCustomLogic + + +class TQCCatalog(SACCatalog): + """Catalog class for building TQC models. + + TQC extends SAC by using distributional critics with quantile regression. + Each critic outputs `n_quantiles` values instead of a single Q-value. + + The catalog builds: + - Pi Encoder: Same as SAC (encodes observations for the actor) + - Pi Head: Same as SAC (outputs mean and log_std for Squashed Gaussian) + - QF Encoders: Multiple encoders for quantile critics + - QF Heads: Multiple heads, each outputting n_quantiles values + """ + + def __init__( + self, + observation_space: gym.Space, + action_space: gym.Space, + model_config_dict: dict, + view_requirements: dict = None, + ): + """Initializes the TQCCatalog. + + Args: + observation_space: The observation space of the environment. + action_space: The action space of the environment. + model_config_dict: The model config dictionary containing + TQC-specific parameters like n_quantiles and n_critics. + view_requirements: Not used, kept for API compatibility. + """ + # Extract TQC-specific parameters before calling super().__init__ + self.n_quantiles = model_config_dict.get("n_quantiles", 25) + self.n_critics = model_config_dict.get("n_critics", 2) + + super().__init__( + observation_space=observation_space, + action_space=action_space, + model_config_dict=model_config_dict, + view_requirements=view_requirements, + ) + + # Override the QF head config to output n_quantiles instead of 1 + # For TQC, we always output n_quantiles (continuous action space) + self.qf_head_config = MLPHeadConfig( + input_dims=self.latent_dims, + hidden_layer_dims=self.pi_and_qf_head_hiddens, + hidden_layer_activation=self.pi_and_qf_head_activation, + output_layer_activation="linear", + output_layer_dim=self.n_quantiles, + ) + + @OverrideToImplementCustomLogic + def build_qf_encoder(self, framework: str) -> Encoder: + """Builds a Q-function encoder for TQC. + + Same as SAC - encodes state-action pairs. + + Args: + framework: The framework to use ("torch"). + + Returns: + The encoder for the Q-network. + """ + return super().build_qf_encoder(framework=framework) + + @OverrideToImplementCustomLogic + def build_qf_head(self, framework: str) -> Model: + """Builds a Q-function head that outputs n_quantiles values. + + Args: + framework: The framework to use ("torch"). + + Returns: + The Q-function head outputting n_quantiles values. + """ + return self.qf_head_config.build(framework=framework) diff --git a/rllib/algorithms/tqc/tqc_learner.py b/rllib/algorithms/tqc/tqc_learner.py new file mode 100644 index 000000000000..4792ee21b2f3 --- /dev/null +++ b/rllib/algorithms/tqc/tqc_learner.py @@ -0,0 +1,53 @@ +""" +TQC Learner base class. + +Extends SAC Learner with quantile-specific loss computation. +""" + +from ray.rllib.algorithms.sac.sac_learner import SACLearner +from ray.rllib.core.learner.learner import Learner +from ray.rllib.utils.annotations import override +from ray.rllib.utils.typing import ModuleID + +# Loss keys for TQC +QF_LOSS_KEY = "qf_loss" +QF_MEAN_KEY = "qf_mean" +QF_MAX_KEY = "qf_max" +QF_MIN_KEY = "qf_min" +QUANTILES_KEY = "quantiles" +TD_ERROR_MEAN_KEY = "td_error_mean" + + +class TQCLearner(SACLearner): + """Base Learner class for TQC algorithm. 
+ + TQC extends SAC with distributional critics using quantile regression. The main differences are: + - Uses quantile Huber loss instead of standard Huber/MSE loss + - Computes target Q-values by sorting and truncating top quantiles + """ + + @override(Learner) + def build(self) -> None: + """Builds the TQC learner.""" + # Call parent build (handles alpha/entropy coefficient) + super().build() + + def _get_n_target_quantiles(self, module_id: ModuleID) -> int: + """Returns the number of target quantiles after truncation. + + Args: + module_id: The module ID. + + Returns: + Number of quantiles to use for target computation. + """ + config = self.config.get_config_for_module(module_id) + n_quantiles = config.n_quantiles + n_critics = config.n_critics + top_quantiles_to_drop = config.top_quantiles_to_drop_per_net + + total_quantiles = n_quantiles * n_critics + quantiles_to_drop = top_quantiles_to_drop * n_critics + + return total_quantiles - quantiles_to_drop diff --git a/rllib/examples/algorithms/tqc/humanoid_tqc.py b/rllib/examples/algorithms/tqc/humanoid_tqc.py new file mode 100644 index 000000000000..42ee29fa57cf --- /dev/null +++ b/rllib/examples/algorithms/tqc/humanoid_tqc.py @@ -0,0 +1,73 @@ +"""TQC on Humanoid-v4. + +On a single-GPU machine, with the `--num-gpus-per-learner=1` command line option, this +example should reach an episode return of >1000 in ~10h. TQC's truncated quantile critics +can help reduce overestimation bias compared to SAC. Some more hyperparameter +fine-tuning, longer runs, and more scale (`--num-learners > 0` and `--num-env-runners > 0`) +should help push this up. +""" + +from torch import nn + +from ray.rllib.algorithms.tqc.tqc import TQCConfig +from ray.rllib.core.rl_module.default_model_config import DefaultModelConfig +from ray.rllib.examples.utils import ( + add_rllib_example_script_args, + run_rllib_example_script_experiment, +) + +parser = add_rllib_example_script_args( + default_timesteps=1000000, + default_reward=12000.0, + default_iters=2000, +) +# Use `parser` to add your own custom command line options to this script +# and (if needed) use their values to set up `config` below.
+args = parser.parse_args() + + +config = ( + TQCConfig() + .environment("Humanoid-v4") + .training( + initial_alpha=1.001, + actor_lr=0.00005, + critic_lr=0.00005, + alpha_lr=0.00005, + target_entropy="auto", + n_step=(1, 3), + tau=0.005, + train_batch_size_per_learner=256, + target_network_update_freq=1, + # TQC-specific parameters + n_quantiles=25, + n_critics=2, + top_quantiles_to_drop_per_net=2, + replay_buffer_config={ + "type": "PrioritizedEpisodeReplayBuffer", + "capacity": 1000000, + "alpha": 0.6, + "beta": 0.4, + }, + num_steps_sampled_before_learning_starts=10000, + ) + .rl_module( + model_config=DefaultModelConfig( + fcnet_hiddens=[1024, 1024], + fcnet_activation="relu", + fcnet_kernel_initializer=nn.init.xavier_uniform_, + head_fcnet_hiddens=[], + head_fcnet_activation=None, + head_fcnet_kernel_initializer="orthogonal_", + head_fcnet_kernel_initializer_kwargs={"gain": 0.01}, + ) + ) + .reporting( + metrics_num_episodes_for_smoothing=5, + min_sample_timesteps_per_iteration=1000, + ) +) + + +if __name__ == "__main__": + run_rllib_example_script_experiment(config, args) diff --git a/rllib/examples/algorithms/tqc/pendulum_tqc.py b/rllib/examples/algorithms/tqc/pendulum_tqc.py new file mode 100644 index 000000000000..e062312cff42 --- /dev/null +++ b/rllib/examples/algorithms/tqc/pendulum_tqc.py @@ -0,0 +1,63 @@ +from torch import nn + +from ray.rllib.algorithms.tqc.tqc import TQCConfig +from ray.rllib.core.rl_module.default_model_config import DefaultModelConfig +from ray.rllib.examples.utils import ( + add_rllib_example_script_args, + run_rllib_example_script_experiment, +) + +parser = add_rllib_example_script_args( + default_timesteps=20000, + default_reward=-250.0, +) +# Use `parser` to add your own custom command line options to this script +# and (if needed) use their values to set up `config` below. +args = parser.parse_args() + +config = ( + TQCConfig() + .environment("Pendulum-v1") + .training( + initial_alpha=1.001, + # Use a smaller learning rate for the policy. + actor_lr=2e-4 * (args.num_learners or 1) ** 0.5, + critic_lr=8e-4 * (args.num_learners or 1) ** 0.5, + alpha_lr=9e-4 * (args.num_learners or 1) ** 0.5, + lr=None, + target_entropy="auto", + n_step=(2, 5), + tau=0.005, + train_batch_size_per_learner=256, + target_network_update_freq=1, + # TQC-specific parameters + n_quantiles=25, + n_critics=2, + top_quantiles_to_drop_per_net=2, + replay_buffer_config={ + "type": "PrioritizedEpisodeReplayBuffer", + "capacity": 100000, + "alpha": 1.0, + "beta": 0.0, + }, + num_steps_sampled_before_learning_starts=256 * (args.num_learners or 1), + ) + .rl_module( + model_config=DefaultModelConfig( + fcnet_hiddens=[256, 256], + fcnet_activation="relu", + fcnet_kernel_initializer=nn.init.xavier_uniform_, + head_fcnet_hiddens=[], + head_fcnet_activation=None, + head_fcnet_kernel_initializer="orthogonal_", + head_fcnet_kernel_initializer_kwargs={"gain": 0.01}, + ), + ) + .reporting( + metrics_num_episodes_for_smoothing=5, + ) +) + + +if __name__ == "__main__": + run_rllib_example_script_experiment(config, args)
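
Below is a small, self-contained sketch (not part of the diff) of the core TQC computation this patch implements: pooling target quantiles from all critics, sorting them and dropping the top ones, then scoring one critic's quantiles with the quantile Huber loss. The shapes, the `kappa=1.0` threshold, and all tensor values are illustrative only; the logic mirrors `forward_target()` and `quantile_huber_loss_per_sample()` in the diff above.

```python
# Illustrative only -- mirrors the truncation in forward_target() and the loss in
# quantile_huber_loss_per_sample(); all shapes and values here are made up.
import torch

n_critics = 2
n_quantiles = 25
top_quantiles_to_drop_per_net = 2
batch_size = 4

# Pretend output of the target critics: (batch, n_critics, n_quantiles).
target_quantiles = torch.randn(batch_size, n_critics, n_quantiles)

# Pool quantiles from all critics, sort, and drop the top
# top_quantiles_to_drop_per_net * n_critics values to curb overestimation.
pooled = target_quantiles.reshape(batch_size, -1)
pooled_sorted, _ = torch.sort(pooled, dim=1)
n_keep = n_critics * n_quantiles - top_quantiles_to_drop_per_net * n_critics
truncated = pooled_sorted[:, :n_keep]  # (batch, n_keep)

# Quantile Huber loss (kappa=1.0) of one critic's quantiles vs. the truncated targets.
current = torch.randn(batch_size, n_quantiles)  # one critic's quantile predictions
tau = (torch.arange(n_quantiles, dtype=torch.float32) + 0.5) / n_quantiles
td = truncated.unsqueeze(1) - current.unsqueeze(2)  # (batch, n_quantiles, n_keep)
huber = torch.where(td.abs() <= 1.0, 0.5 * td ** 2, td.abs() - 0.5)
weight = torch.abs(tau.view(1, -1, 1) - (td < 0).float())
per_sample_loss = (weight * huber).sum(dim=(1, 2))
print(per_sample_loss.shape)  # torch.Size([4]) -- one loss value per batch element
```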