diff --git a/rllib/BUILD b/rllib/BUILD
index 3d3d56c914c34..7556d335854f2 100644
--- a/rllib/BUILD
+++ b/rllib/BUILD
@@ -1070,6 +1070,19 @@ py_test(
     srcs = ["algorithms/sac/tests/test_sac.py"]
 )
 
+# --------------------------------------------------------------------
+# Callback tests
+# rllib/callbacks/
+#
+# Tag: callbacks_dir
+# --------------------------------------------------------------------
+py_test(
+    name = "test_multicallback",
+    tags = ["team:rllib", "callbacks_dir"],
+    size = "small",
+    srcs = ["callbacks/tests/test_multicallback.py"]
+)
+
 # --------------------------------------------------------------------
 # ConnectorV2 tests
 # rllib/connector/
diff --git a/rllib/algorithms/algorithm_config.py b/rllib/algorithms/algorithm_config.py
index 8ef9d6de5cf28..9c815c9d59ec4 100644
--- a/rllib/algorithms/algorithm_config.py
+++ b/rllib/algorithms/algorithm_config.py
@@ -39,7 +39,7 @@
 from ray.rllib.offline.io_context import IOContext
 from ray.rllib.policy.policy import Policy, PolicySpec
 from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID
-from ray.rllib.utils import deep_update, merge_dicts
+from ray.rllib.utils import deep_update, force_list, merge_dicts
 from ray.rllib.utils.annotations import (
     OldAPIStack,
     OverrideToImplementCustomLogic_CallToSuperRecommended,
@@ -2518,9 +2518,10 @@ def callbacks(
             # Check, whether given `callbacks` is a callable.
             # TODO (sven): Once the old API stack is deprecated, this can also be None
             # (which should then become the default value for this attribute).
-            if not callable(callbacks_class):
+            to_check = force_list(callbacks_class)
+            if not all(callable(c) for c in to_check):
                 raise ValueError(
-                    "`config.callbacks_class` must be a callable method that "
+                    "`config.callbacks_class` must be a callable or list of callables that "
                     "returns a subclass of DefaultCallbacks, got "
                     f"{callbacks_class}!"
                 )
diff --git a/rllib/callbacks/tests/__init__.py b/rllib/callbacks/tests/__init__.py
new file mode 100644
index 0000000000000..e69de29bb2d1d
diff --git a/rllib/callbacks/tests/test_multicallback.py b/rllib/callbacks/tests/test_multicallback.py
new file mode 100644
index 0000000000000..715ddde695287
--- /dev/null
+++ b/rllib/callbacks/tests/test_multicallback.py
@@ -0,0 +1,150 @@
+import unittest
+import ray
+from ray.rllib.algorithms import PPOConfig
+from ray.rllib.callbacks.callbacks import RLlibCallback
+
+
+class TestMultiCallback(unittest.TestCase):
+    """A test suite to test the `MultiCallback`."""
+
+    @classmethod
+    def setUpClass(cls) -> None:
+        # Start Ray once for the whole class instead of once per test.
+        ray.init()
+
+    @classmethod
+    def tearDownClass(cls) -> None:
+        # Shut Ray down after the last test of this class has run.
+        ray.shutdown()
+
+    def test_multicallback_with_custom_callback_function(self):
+        """Tests if callbacks in `MultiCallback` get executed.
+
+        This also tests, if multiple callbacks from different sources, i.e.
+        `callbacks_class` and `on_episode_step` run correctly.
+        """
+        # Define two standard `RLlibCallback` subclasses.
+        class TestRLlibCallback1(RLlibCallback):
+            def on_episode_step(
+                self,
+                *,
+                episode,
+                env_runner=None,
+                metrics_logger=None,
+                env=None,
+                env_index,
+                rl_module=None,
+                worker=None,
+                base_env=None,
+                policies=None,
+                **kwargs
+            ):
+
+                metrics_logger.log_value(
+                    "callback_1", 1, reduce="mean", clear_on_reduce=True
+                )
+
+        class TestRLlibCallback2(RLlibCallback):
+            def on_episode_step(
+                self,
+                *,
+                episode,
+                env_runner=None,
+                metrics_logger=None,
+                env=None,
+                env_index,
+                rl_module=None,
+                worker=None,
+                base_env=None,
+                policies=None,
+                **kwargs
+            ):
+
+                metrics_logger.log_value(
+                    "callback_2", 2, reduce="mean", clear_on_reduce=True
+                )
+
+        # Define a custom callback function.
+        def custom_on_episode_step_callback(
+            episode,
+            env_runner=None,
+            metrics_logger=None,
+            env=None,
+            env_index=None,
+            rl_module=None,
+            worker=None,
+            base_env=None,
+            policies=None,
+            **kwargs
+        ):
+
+            metrics_logger.log_value(
+                "custom_callback", 3, reduce="mean", clear_on_reduce=True
+            )
+
+        # Configure the algorithm.
+        config = (
+            PPOConfig()
+            .environment("CartPole-v1")
+            .api_stack(
+                enable_env_runner_and_connector_v2=True,
+                enable_rl_module_and_learner=True,
+            )
+            # Use the callbacks and callback function.
+            .callbacks(
+                callbacks_class=[TestRLlibCallback1, TestRLlibCallback2],
+                on_episode_step=custom_on_episode_step_callback,
+            )
+        )
+
+        # Build the algorithm. At this stage, callbacks get already validated.
+        algo = config.build()
+
+        # Run 10 training iterations and check, if the metrics defined in the
+        # callbacks made it into the results. Furthermore, check, if the values
+        # are correct.
+        for _ in range(10):
+            results = algo.train()
+            self.assertIn("callback_1", results["env_runners"])
+            self.assertIn("callback_2", results["env_runners"])
+            self.assertIn("custom_callback", results["env_runners"])
+            self.assertAlmostEqual(results["env_runners"]["callback_1"], 1)
+            self.assertAlmostEqual(results["env_runners"]["callback_2"], 2)
+            self.assertAlmostEqual(results["env_runners"]["custom_callback"], 3)
+
+        algo.stop()
+
+    def test_multicallback_validation_error(self):
+        """Check, if the validation safeguard catches wrong `MultiCallback`s."""
+        with self.assertRaises(ValueError):
+            (
+                PPOConfig()
+                .environment("CartPole-v1")
+                .api_stack(
+                    enable_env_runner_and_connector_v2=True,
+                    enable_rl_module_and_learner=True,
+                )
+                # This is wrong b/c it needs callables.
+                .callbacks(callbacks_class=["TestRLlibCallback1", "TestRLlibCallback2"])
+            )
+
+    def test_single_callback_validation_error(self):
+        """Tests if the validation safeguard catches wrong `RLlibCallback`s."""
+        with self.assertRaises(ValueError):
+            (
+                PPOConfig()
+                .environment("CartPole-v1")
+                .api_stack(
+                    enable_env_runner_and_connector_v2=True,
+                    enable_rl_module_and_learner=True,
+                )
+                # This is wrong b/c it needs callables.
+                .callbacks(callbacks_class="TestRLlibCallback")
+            )
+
+
+if __name__ == "__main__":
+    import pytest
+    import sys
+
+    sys.exit(pytest.main(["-v", __file__]))