Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[RLlib] Adjust callback validation to account for MultiCallback. #50920

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions rllib/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -1070,6 +1070,19 @@ py_test(
srcs = ["algorithms/sac/tests/test_sac.py"]
)

# --------------------------------------------------------------------
# Callback tests
# rllib/callbacks/
#
# Tag: callbacks_dir
# --------------------------------------------------------------------
# Runs the `MultiCallback` test suite in callbacks/tests/test_multicallback.py.
py_test(
name = "test_multicallback",
tags = ["team:rllib", "callbacks_dir"],
size = "small",
srcs = ["callbacks/tests/test_multicallback.py"]
)

# --------------------------------------------------------------------
# ConnectorV2 tests
# rllib/connector/
Expand Down
7 changes: 4 additions & 3 deletions rllib/algorithms/algorithm_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@
from ray.rllib.offline.io_context import IOContext
from ray.rllib.policy.policy import Policy, PolicySpec
from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID
from ray.rllib.utils import deep_update, merge_dicts
from ray.rllib.utils import deep_update, force_list, merge_dicts
from ray.rllib.utils.annotations import (
OldAPIStack,
OverrideToImplementCustomLogic_CallToSuperRecommended,
Expand Down Expand Up @@ -2518,9 +2518,10 @@ def callbacks(
# Check, whether given `callbacks` is a callable.
# TODO (sven): Once the old API stack is deprecated, this can also be None
# (which should then become the default value for this attribute).
if not callable(callbacks_class):
to_check = force_list(callbacks_class)
if not all(callable(c) for c in to_check):
raise ValueError(
"`config.callbacks_class` must be a callable method that "
"`config.callbacks_class` must be a callable or list of callables that "
"returns a subclass of DefaultCallbacks, got "
f"{callbacks_class}!"
)
Expand Down
Empty file.
147 changes: 147 additions & 0 deletions rllib/callbacks/tests/test_multicallback.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,147 @@
import unittest
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Very nice! Let's add this to BUILD ...

import ray
from ray.rllib.algorithms import PPOConfig
from ray.rllib.callbacks.callbacks import RLlibCallback


class TestMultiCallback(unittest.TestCase):
    """A test suite for the `MultiCallback`.

    Covers two things:
    - Several callbacks configured at once (a list of `RLlibCallback` classes
      plus a plain `on_episode_step` function) all get executed.
    - The config validation rejects non-callable `callbacks_class` values,
      both as a single value and as a list.
    """

    @classmethod
    def setUpClass(cls) -> None:
        # Start Ray once for the whole suite; only one test actually needs a
        # running cluster, the validation tests fail before anything is built.
        ray.init()

    @classmethod
    def tearDownClass(cls) -> None:
        ray.shutdown()

    def test_multicallback_with_custom_callback_function(self):
        """Tests if callbacks in `MultiCallback` get executed.

        This also tests if multiple callbacks from different sources, i.e.
        `callbacks_class` and `on_episode_step`, run correctly side by side.
        """

        # Define two standard `RLlibCallback`s. Each one logs a distinct
        # constant under a distinct key, so we can tell them apart in the
        # results.
        class TestRLlibCallback1(RLlibCallback):
            def on_episode_step(
                self,
                *,
                episode,
                env_runner=None,
                metrics_logger=None,
                env=None,
                env_index,
                rl_module=None,
                worker=None,
                base_env=None,
                policies=None,
                **kwargs
            ):
                metrics_logger.log_value(
                    "callback_1", 1, reduce="mean", clear_on_reduce=True
                )

        class TestRLlibCallback2(RLlibCallback):
            def on_episode_step(
                self,
                *,
                episode,
                env_runner=None,
                metrics_logger=None,
                env=None,
                env_index,
                rl_module=None,
                worker=None,
                base_env=None,
                policies=None,
                **kwargs
            ):
                metrics_logger.log_value(
                    "callback_2", 2, reduce="mean", clear_on_reduce=True
                )

        # Define a custom callback function (not bound to any callback class).
        def custom_on_episode_step_callback(
            episode,
            env_runner=None,
            metrics_logger=None,
            env=None,
            env_index=None,
            rl_module=None,
            worker=None,
            base_env=None,
            policies=None,
            **kwargs
        ):
            metrics_logger.log_value(
                "custom_callback", 3, reduce="mean", clear_on_reduce=True
            )

        # Configure the algorithm.
        config = (
            PPOConfig()
            .environment("CartPole-v1")
            .api_stack(
                enable_env_runner_and_connector_v2=True,
                enable_rl_module_and_learner=True,
            )
            # Use the callbacks and callback function.
            .callbacks(
                callbacks_class=[TestRLlibCallback1, TestRLlibCallback2],
                on_episode_step=custom_on_episode_step_callback,
            )
        )

        # Build the algorithm. At this stage, callbacks already get validated.
        algo = config.build()
        # Make sure the algorithm is stopped even if an assertion below fails.
        self.addCleanup(algo.stop)

        # Run 10 training iterations and check whether the metrics defined in
        # the callbacks made it into the results. Furthermore, check whether
        # the values are correct.
        for _ in range(10):
            results = algo.train()
            env_runner_results = results["env_runners"]
            self.assertIn("callback_1", env_runner_results)
            self.assertIn("callback_2", env_runner_results)
            self.assertIn("custom_callback", env_runner_results)
            # NOTE: `assertAlmostEquals` is a deprecated alias that was removed
            # in Python 3.12; use `assertAlmostEqual`.
            self.assertAlmostEqual(env_runner_results["callback_1"], 1)
            self.assertAlmostEqual(env_runner_results["callback_2"], 2)
            self.assertAlmostEqual(env_runner_results["custom_callback"], 3)

    def test_multicallback_validation_error(self):
        """Checks if the validation safeguard catches wrong `MultiCallback`s."""
        with self.assertRaises(ValueError):
            (
                PPOConfig()
                .environment("CartPole-v1")
                .api_stack(
                    enable_env_runner_and_connector_v2=True,
                    enable_rl_module_and_learner=True,
                )
                # This is wrong b/c it needs callables (strings are given).
                .callbacks(
                    callbacks_class=["TestRLlibCallback1", "TestRLlibCallback2"]
                )
            )

    def test_single_callback_validation_error(self):
        """Tests if the validation safeguard catches wrong `RLlibCallback`s."""
        with self.assertRaises(ValueError):
            (
                PPOConfig()
                .environment("CartPole-v1")
                .api_stack(
                    enable_env_runner_and_connector_v2=True,
                    enable_rl_module_and_learner=True,
                )
                # This is wrong b/c it needs a callable (a string is given).
                .callbacks(callbacks_class="TestRLlibCallback")
            )


if __name__ == "__main__":
    import sys

    import pytest

    # Run only this file's tests in verbose mode and hand pytest's return
    # code back to the shell as the process exit status.
    exit_code = pytest.main(["-v", __file__])
    sys.exit(exit_code)