12 changes: 12 additions & 0 deletions rllib/BUILD.bazel
@@ -2166,6 +2166,18 @@ py_test(
deps = [":conftest"],
)

# TQC
py_test(
name = "test_tqc",
size = "large",
srcs = ["algorithms/tqc/tests/test_tqc.py"],
tags = [
"algorithms_dir",
"team:rllib",
],
deps = [":conftest"],
)

# Generic testing
py_test(
name = "algorithms/tests/test_custom_resource",
95 changes: 95 additions & 0 deletions rllib/algorithms/tqc/README.md
@@ -0,0 +1,95 @@
# TQC (Truncated Quantile Critics)

## Overview

TQC is an extension of SAC (Soft Actor-Critic) that uses distributional reinforcement learning with quantile regression to control overestimation bias in the Q-function.

**Paper**: [Controlling Overestimation Bias with Truncated Mixture of Continuous Distributional Quantile Critics](https://arxiv.org/abs/2005.04269)

## Key Features

- **Distributional Critics**: Each critic network outputs multiple quantiles instead of a single Q-value
- **Multiple Critics**: Uses `n_critics` independent critic networks (default: 2)
- **Truncated Targets**: Drops the top quantiles when computing target Q-values to reduce overestimation
- **Quantile Huber Loss**: Uses quantile regression with Huber loss for critic training

## Usage

```python
from ray.rllib.algorithms.tqc import TQCConfig

config = (
    TQCConfig()
    .environment("Pendulum-v1")
    .training(
        n_quantiles=25,  # Number of quantiles per critic
        n_critics=2,  # Number of critic networks
        top_quantiles_to_drop_per_net=2,  # Quantiles to drop for bias control
    )
)

algo = config.build()
for _ in range(100):
    result = algo.train()
    print(f"Episode return mean: {result['env_runners']['episode_return_mean']}")
```

## Configuration

### TQC-Specific Parameters

| Parameter | Default | Description |
| ------------------------------- | ------- | ------------------------------------------------------------------ |
| `n_quantiles` | 25 | Number of quantiles for each critic network |
| `n_critics` | 2 | Number of critic networks |
| `top_quantiles_to_drop_per_net` | 2 | Number of top quantiles to drop per network when computing targets |

### Inherited from SAC

TQC inherits all SAC parameters, including the following (see the configuration sketch after this list):

- `actor_lr`, `critic_lr`, `alpha_lr`: Learning rates
- `tau`: Target network update coefficient
- `initial_alpha`: Initial entropy coefficient
- `target_entropy`: Target entropy for automatic alpha tuning
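
Because `TQCConfig` builds on SAC's configuration, the inherited parameters above are set through the same `training()` call as the TQC-specific ones. The sketch below is illustrative only; the values are not tuned recommendations.

```python
from ray.rllib.algorithms.tqc import TQCConfig

config = (
    TQCConfig()
    .environment("Pendulum-v1")
    .training(
        # TQC-specific parameters.
        n_quantiles=25,
        n_critics=2,
        top_quantiles_to_drop_per_net=2,
        # Parameters inherited from SAC (illustrative values).
        actor_lr=3e-4,
        critic_lr=3e-4,
        alpha_lr=3e-4,
        tau=0.005,
        initial_alpha=1.0,
        target_entropy="auto",
    )
)
```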

## Algorithm Details

### Critic Update

1. Each critic outputs `n_quantiles` quantile estimates
2. For target computation:
- Collect all quantiles from all critics: `n_critics * n_quantiles` values
- Sort all quantiles
- Drop the top `top_quantiles_to_drop_per_net * n_critics` quantiles
- Use remaining quantiles as targets
3. Train critics using the quantile Huber loss (sketched below)
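
Below is a minimal PyTorch sketch of the truncation (step 2) and the quantile Huber loss (step 3). The function and argument names are illustrative and do not mirror RLlib's actual implementation; quantile tensors are assumed to have shape `[batch, n_critics, n_quantiles]`.

```python
import torch


def truncated_target_quantiles(
    next_quantiles,  # [batch, n_critics, n_quantiles] from the target critics
    next_log_probs,  # [batch], log pi(a'|s') for re-sampled next actions
    rewards,  # [batch]
    dones,  # [batch], 1.0 where the episode terminated
    gamma,
    top_quantiles_to_drop_per_net,
    alpha,
):
    batch, n_critics, n_quantiles = next_quantiles.shape
    # Pool the quantiles of all critics and sort them in ascending order.
    sorted_q, _ = torch.sort(next_quantiles.reshape(batch, -1), dim=1)
    # Drop the top `top_quantiles_to_drop_per_net * n_critics` atoms (upper tail).
    keep = n_critics * (n_quantiles - top_quantiles_to_drop_per_net)
    truncated = sorted_q[:, :keep]
    # Soft Bellman backup with the SAC entropy bonus.
    soft = truncated - alpha * next_log_probs.unsqueeze(-1)
    return rewards.unsqueeze(-1) + gamma * (1.0 - dones.unsqueeze(-1)) * soft


def quantile_huber_loss(current_quantiles, target_quantiles, kappa=1.0):
    # current_quantiles: [batch, n_critics, n_quantiles]
    # target_quantiles:  [batch, n_kept_target_quantiles] (detached)
    _, _, n_quantiles = current_quantiles.shape
    # Quantile midpoints tau_i = (2i + 1) / (2 * n_quantiles).
    taus = (
        torch.arange(n_quantiles, device=current_quantiles.device) + 0.5
    ) / n_quantiles
    # Pairwise TD errors: [batch, n_critics, n_quantiles, n_targets].
    td = target_quantiles[:, None, None, :] - current_quantiles[..., None]
    huber = torch.where(
        td.abs() <= kappa, 0.5 * td.pow(2), kappa * (td.abs() - 0.5 * kappa)
    )
    # Asymmetric quantile weights |tau - 1{td < 0}|.
    return ((taus[None, None, :, None] - (td < 0).float()).abs() * huber / kappa).mean()
```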

### Actor Update

- Maximize the expected Q-value (the mean of all quantiles) minus the entropy bonus
- Same as SAC, but using the mean of the quantile estimates (see the sketch below)
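
A short sketch of this objective; names and shapes are illustrative, not RLlib's actual training code:

```python
import torch


def tqc_actor_loss(quantiles, log_probs, alpha):
    # quantiles: [batch, n_critics, n_quantiles] for actions re-sampled from pi.
    # log_probs: [batch], log pi(a|s) of those actions.
    # Averaging over critics and quantiles approximates E[Q(s, a)].
    q_mean = quantiles.mean(dim=(1, 2))
    # As in SAC: minimize alpha * log_pi - Q.
    return (alpha * log_probs - q_mean).mean()
```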

### Entropy Tuning

- Same as SAC: automatically adjusts the temperature parameter α (a sketch follows)
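
For completeness, a sketch of the SAC-style temperature update that TQC reuses, in one common formulation (not the exact RLlib code):

```python
import torch


def alpha_loss(log_alpha, log_probs, target_entropy):
    # log_alpha: learnable log-temperature; log_probs: [batch], detached from the actor.
    # Pushes the policy entropy toward target_entropy (e.g., -action_dim for "auto").
    return -(log_alpha * (log_probs + target_entropy).detach()).mean()
```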

## Differences from SAC

| Aspect | SAC | TQC |
| ----------------- | -------------- | ----------------------------- |
| Critic Output | Single Q-value | `n_quantiles` quantile values |
| Number of Critics | 2 (twin_q) | `n_critics` (configurable) |
| Loss Function | Huber/MSE | Quantile Huber Loss |
| Target Q | min(Q1, Q2) | Truncated sorted quantiles |

## References

```bibtex
@article{kuznetsov2020controlling,
  title={Controlling Overestimation Bias with Truncated Mixture of Continuous Distributional Quantile Critics},
  author={Kuznetsov, Arsenii and Shvechikov, Pavel and Grishin, Alexander and Vetrov, Dmitry},
  journal={arXiv preprint arXiv:2005.04269},
  year={2020}
}
```
13 changes: 13 additions & 0 deletions rllib/algorithms/tqc/__init__.py
@@ -0,0 +1,13 @@
"""TQC (Truncated Quantile Critics) Algorithm.

Paper: https://arxiv.org/abs/2005.04269
"""

from ray.rllib.algorithms.tqc.tqc import TQC, TQCConfig
from ray.rllib.algorithms.tqc.tqc_catalog import TQCCatalog

__all__ = [
    "TQC",
    "TQCConfig",
    "TQCCatalog",
]
100 changes: 100 additions & 0 deletions rllib/algorithms/tqc/default_tqc_rl_module.py
@@ -0,0 +1,100 @@
"""
Default TQC RLModule.

TQC uses distributional critics with quantile regression.
"""

from typing import List, Tuple

from ray.rllib.core.learner.utils import make_target_network
from ray.rllib.core.rl_module.apis import InferenceOnlyAPI, QNetAPI, TargetNetworkAPI
from ray.rllib.core.rl_module.rl_module import RLModule
from ray.rllib.utils.annotations import (
    override,
)
from ray.rllib.utils.typing import NetworkType
from ray.util.annotations import DeveloperAPI


@DeveloperAPI
class DefaultTQCRLModule(RLModule, InferenceOnlyAPI, TargetNetworkAPI, QNetAPI):
Review comment (Contributor, severity: high):
The DefaultTQCTorchRLModule duplicates a significant amount of logic from this base class (e.g., in setup, make_target_networks, get_target_network_pairs) instead of inheriting from it. This leads to code duplication and makes maintenance harder.

To improve this, DefaultTQCTorchRLModule should inherit from DefaultTQCRLModule. This would require some refactoring to make the base class methods more framework-agnostic (e.g., by not initializing lists for models directly, but letting subclasses handle it).

For example, setup() in the base class could be refactored to call framework-specific model creation methods. This would allow for better code reuse and a cleaner architecture, following the pattern seen in other RLlib algorithms like PPO.

Additionally, the abstract method _qf_forward_helper seems unused in the PyTorch implementation and could potentially be removed. The signature of _qf_forward_all_critics also differs between this base class and the PyTorch implementation, which would need to be reconciled if inheritance is used.

"""RLModule for the TQC (Truncated Quantile Critics) algorithm.

TQC extends SAC by using distributional critics with quantile regression.
Each critic outputs n_quantiles values instead of a single Q-value.

Architecture:
- Policy (Actor): Same as SAC
[obs] -> [pi_encoder] -> [pi_head] -> [action_dist_inputs]

- Quantile Critics: Multiple critics, each outputting n_quantiles
[obs, action] -> [qf_encoder_i] -> [qf_head_i] -> [n_quantiles values]

- Target Quantile Critics: Target networks for each critic
[obs, action] -> [target_qf_encoder_i] -> [target_qf_head_i] -> [n_quantiles]
"""

    @override(RLModule)
    def setup(self):
        # TQC-specific parameters from model_config
        self.n_quantiles = self.model_config.get("n_quantiles", 25)
        self.n_critics = self.model_config.get("n_critics", 2)
        self.top_quantiles_to_drop_per_net = self.model_config.get(
            "top_quantiles_to_drop_per_net", 2
        )

        # Total quantiles across all critics
        self.quantiles_total = self.n_quantiles * self.n_critics

        # Build the encoder for the policy (same as SAC)
        self.pi_encoder = self.catalog.build_encoder(framework=self.framework)

        # The critic networks are needed only for training; a torch
        # inference-only module skips building them.
        if not self.inference_only or self.framework != "torch":
            # Build multiple Q-function encoders and heads
            self.qf_encoders = []
            self.qf_heads = []

            for _ in range(self.n_critics):
                qf_encoder = self.catalog.build_qf_encoder(framework=self.framework)
                qf_head = self.catalog.build_qf_head(framework=self.framework)
                self.qf_encoders.append(qf_encoder)
                self.qf_heads.append(qf_head)

        # Build the policy head (same as SAC)
        self.pi = self.catalog.build_pi_head(framework=self.framework)

    @override(TargetNetworkAPI)
    def make_target_networks(self):
        """Creates target networks for all quantile critics."""
        self.target_qf_encoders = []
        self.target_qf_heads = []

        for i in range(self.n_critics):
            target_encoder = make_target_network(self.qf_encoders[i])
            target_head = make_target_network(self.qf_heads[i])
            self.target_qf_encoders.append(target_encoder)
            self.target_qf_heads.append(target_head)

    @override(InferenceOnlyAPI)
    def get_non_inference_attributes(self) -> List[str]:
        """Returns attributes not needed for inference."""
        return [
            "qf_encoders",
            "qf_heads",
            "target_qf_encoders",
            "target_qf_heads",
        ]

    @override(TargetNetworkAPI)
    def get_target_network_pairs(self) -> List[Tuple[NetworkType, NetworkType]]:
        """Returns pairs of (network, target_network) for updating targets."""
        pairs = []
        for i in range(self.n_critics):
            pairs.append((self.qf_encoders[i], self.target_qf_encoders[i]))
            pairs.append((self.qf_heads[i], self.target_qf_heads[i]))
        return pairs

    @override(RLModule)
    def get_initial_state(self) -> dict:
        """TQC does not support RNNs yet."""
        return {}
Empty file.