diff --git a/doc/source/rllib/rllib-algorithms.rst b/doc/source/rllib/rllib-algorithms.rst
index 0fb9a8cff6fa..f9b61fb1072d 100644
--- a/doc/source/rllib/rllib-algorithms.rst
+++ b/doc/source/rllib/rllib-algorithms.rst
@@ -154,8 +154,10 @@ Soft Actor Critic (SAC)
 
 **Tuned examples:**
 
-`Pendulum-v1 `__,
-`HalfCheetah-v3 `__,
+`CartPole-v1 `__,
+`Atari (Pong-v5) `__,
+`with LSTM `__,
+`Multi-Agent `__,
 
 **SAC-specific configs** (see also :ref:`generic algorithm settings `):
 
@@ -195,8 +197,10 @@ Asynchronous Proximal Policy Optimization (APPO)
 
 **Tuned examples:**
 
-`Pong-v5 `__
-`HalfCheetah-v4 `__
+`Atari (Pong-v5) `__
+`MuJoCo (Humanoid-v4) `__
+`Using an LSTM `__
+`Multi-Agent `__
 
 **APPO-specific configs** (see also :ref:`generic algorithm settings `):
 
diff --git a/rllib/BUILD.bazel b/rllib/BUILD.bazel
index ff4c592b5abe..946e42214cd3 100644
--- a/rllib/BUILD.bazel
+++ b/rllib/BUILD.bazel
@@ -87,12 +87,29 @@ py_library(
 )
 
 # --------------------------------------------------------------------
-# Algorithms learning regression tests.
+# Algorithms learning regression tests (rllib/examples/algorithms/[algo-name]).
 #
 # Tag: learning_tests
 #
-# This will test python/yaml config files
-# inside rllib/examples/algorithms/[algo-name] for actual learning success.
+# These tests check that the algorithm achieves above-random performance within a
+# relatively short period of time, not that it reaches the optimal policy.
+#
+# When moving from a single-learner to a multi-learner test, tighten the expected outcome:
+# either reduce the maximum iterations or samples, or raise the target return,
+# so that the multi-learner setup must achieve something a single learner normally couldn't.
+#
+# Compute Config
+# - local (CPU) = 7 CPUs, 0 GPUs: 5 Env Runners, 0 Learners on CPU, 2 Aggregator Actors per Learner on CPU
+# - single (CPU) = 8 CPUs, 0 GPUs: 5 Env Runners, 1 Learner on CPU, 2 Aggregator Actors per Learner on CPU
+# - single (GPU) = 8 CPUs, 1 GPU: 5 Env Runners, 1 Learner on GPU, 2 Aggregator Actors per Learner on CPU
+# - multi (GPU) = 16 CPUs, 2 GPUs: 10 Env Runners, 2 Learners on GPU, 2 Aggregator Actors per Learner on CPU (4 aggregator CPUs in total)
+#
+# Legend
+# - SA = Single-agent environment
+# - MA = Multi-agent environment
+# - D = Discrete actions
+# - C = Continuous actions
+# - LSTM = recurrent policy (using an LSTM)
 # --------------------------------------------------------------------
 
 # APPO
@@ -1622,238 +1639,91 @@ py_test(
     ],
 )
 
-# SAC
-# MountainCar
-py_test(
-    name = "learning_tests_mountaincar_sac",
-    size = "large",
-    srcs = ["examples/algorithms/sac/mountaincar_sac.py"],
-    args = [
-        "--as-test",
-    ],
-    main = "examples/algorithms/sac/mountaincar_sac.py",
-    tags = [
-        "exclusive",
-        "learning_tests",
-        "learning_tests_discrete",
-        "team:rllib",
-        "torch_only",
-    ],
-)
-
-py_test(
-    name = "learning_tests_mountaincar_sac_gpu",
-    size = "large",
-    srcs = ["examples/algorithms/sac/mountaincar_sac.py"],
-    args = [
-        "--as-test",
-        "--num-learners=1",
-        "--num-gpus-per-learner=1",
-    ],
-    main = "examples/algorithms/sac/mountaincar_sac.py",
-    tags = [
-        "exclusive",
-        "gpu",
-        "learning_tests",
-        "learning_tests_discrete",
-        "team:rllib",
-        "torch_only",
-    ],
-)
-
-py_test(
-    name = "learning_tests_mountaincar_sac_multi_cpu",
-    size = "large",
-    srcs = ["examples/algorithms/sac/mountaincar_sac.py"],
-    args = [
-        "--as-test",
-        "--num-learners=2",
-    ],
-    main = "examples/algorithms/sac/mountaincar_sac.py",
-    tags = [
-        "exclusive",
-        "learning_tests",
-
"learning_tests_discrete", - "team:rllib", - "torch_only", - ], -) - -py_test( - name = "learning_tests_mountaincar_sac_multi_gpu", - size = "large", - timeout = "eternal", - srcs = ["examples/algorithms/sac/mountaincar_sac.py"], - args = [ - "--as-test", - "--num-learners=2", - "--num-gpus-per-learner=1", - ], - main = "examples/algorithms/sac/mountaincar_sac.py", - tags = [ - "exclusive", - "learning_tests", - "learning_tests_discrete", - "multi_gpu", - "team:rllib", - "torch_only", - ], -) +# | SAC (14 total tests) | | Number of Learners (Device) | +# | Environment | Success | Local (CPU) | Single (CPU) | Single (GPU) | Multi (GPU) | +# |--------------------------------|---------|-------------|-----------------|--------------|-------------| +# | (SA/D/LSTM) Stateless Cartpole | 150 | ✅ | ❌ | ❌ | ❌ | +# | (MA/D) TicTacToe | -2.0 | ❌ | ✅ | ❌ | ❌ | +# | (SA/D) Atari (Pong) | 5 | ❌ | ❌ | ❌ | ✅ | +# | (SA/C) MuJoCo (Humanoid) | 200 | ❌ | ❌ | ✅ | ❌ | -# Pendulum py_test( - name = "learning_tests_pendulum_sac", + name = "learning_tests_sac_stateless_cartpole_local", size = "large", - srcs = ["examples/algorithms/sac/pendulum_sac.py"], + srcs = ["examples/algorithms/sac/stateless_cartpole_sac_with_lstm.py"], args = [ "--as-test", + "--num-cpus=7", + "--num-env-runners=5", + "--num-learners=0", + "--stop-reward=150", ], - main = "examples/algorithms/sac/pendulum_sac.py", + main = "examples/algorithms/sac/stateless_cartpole_sac_with_lstm.py", tags = [ "exclusive", "learning_tests", - "learning_tests_continuous", "team:rllib", - "torch_only", ], ) py_test( - name = "learning_tests_pendulum_sac_gpu", + name = "learning_tests_sac_tictactoe_single_cpu", size = "large", - srcs = ["examples/algorithms/sac/pendulum_sac.py"], + srcs = ["examples/algorithms/sac/tictactoe_sac.py"], args = [ "--as-test", + "--num-cpus=8", + "--num-env-runners=5", "--num-learners=1", - "--num-gpus-per-learner=1", + "--stop-reward=-2", ], - main = "examples/algorithms/sac/pendulum_sac.py", + main = "examples/algorithms/sac/tictactoe_sac.py", tags = [ "exclusive", - "gpu", "learning_tests", - "learning_tests_continuous", - "team:rllib", - "torch_only", - ], -) - -py_test( - name = "learning_tests_pendulum_sac_multi_cpu", - size = "large", - srcs = ["examples/algorithms/sac/pendulum_sac.py"], - args = [ - "--as-test", - "--num-learners=2", - ], - main = "examples/algorithms/sac/pendulum_sac.py", - tags = [ - "exclusive", - "learning_tests", - "learning_tests_continuous", "team:rllib", - "torch_only", ], ) py_test( - name = "learning_tests_pendulum_sac_multi_gpu", + name = "learning_tests_sac_atari_multi_gpu", size = "large", - srcs = ["examples/algorithms/sac/pendulum_sac.py"], + srcs = ["examples/algorithms/sac/atari_sac.py"], args = [ "--as-test", + "--num-cpus=16", + "--num-env-runners=10", "--num-learners=2", "--num-gpus-per-learner=1", + "--stop-reward=5", ], - main = "examples/algorithms/sac/pendulum_sac.py", - tags = [ - "exclusive", - "learning_tests", - "learning_tests_continuous", - "multi_gpu", - "team:rllib", - "torch_only", - ], -) - -# MultiAgentPendulum -py_test( - name = "learning_tests_multi_agent_pendulum_sac", - size = "large", - srcs = ["examples/algorithms/sac/multi_agent_pendulum_sac.py"], - args = [ - "--as-test", - "--num-agents=2", - "--num-cpus=4", - ], - main = "examples/algorithms/sac/multi_agent_pendulum_sac.py", + main = "examples/algorithms/sac/atari_sac.py", tags = [ "exclusive", + "gpu", "learning_tests", - "learning_tests_continuous", "team:rllib", - "torch_only", ], ) py_test( - name = 
"learning_tests_multi_agent_pendulum_sac_gpu", + name = "learning_tests_sac_mujoco_single_gpu", size = "large", - srcs = ["examples/algorithms/sac/multi_agent_pendulum_sac.py"], + srcs = ["examples/algorithms/sac/mujoco_sac.py"], args = [ "--as-test", - "--num-agents=2", - "--num-cpus=4", + "--num-cpus=8", + "--num-env-runners=5", "--num-learners=1", "--num-gpus-per-learner=1", + "--stop-reward=200", ], - main = "examples/algorithms/sac/multi_agent_pendulum_sac.py", - tags = [ - "exclusive", - "gpu", - "learning_tests", - "learning_tests_continuous", - "team:rllib", - "torch_only", - ], -) - -py_test( - name = "learning_tests_multi_agent_pendulum_sac_multi_cpu", - size = "large", - srcs = ["examples/algorithms/sac/multi_agent_pendulum_sac.py"], - args = [ - "--num-agents=2", - "--num-learners=2", - ], - main = "examples/algorithms/sac/multi_agent_pendulum_sac.py", - tags = [ - "exclusive", - "learning_tests", - "learning_tests_continuous", - "team:rllib", - "torch_only", - ], -) - -py_test( - name = "learning_tests_multi_agent_pendulum_sac_multi_gpu", - size = "large", - timeout = "eternal", - srcs = ["examples/algorithms/sac/multi_agent_pendulum_sac.py"], - args = [ - "--num-agents=2", - "--num-learners=2", - "--num-gpus-per-learner=1", - ], - main = "examples/algorithms/sac/multi_agent_pendulum_sac.py", + main = "examples/algorithms/sac/mountaincar_sac.py", tags = [ "exclusive", "learning_tests", - "learning_tests_continuous", "multi_gpu", "team:rllib", - "torch_only", ], ) diff --git a/rllib/core/rl_module/default_model_config.py b/rllib/core/rl_module/default_model_config.py index e53b852a4e7b..4a4cf7a1d309 100644 --- a/rllib/core/rl_module/default_model_config.py +++ b/rllib/core/rl_module/default_model_config.py @@ -132,7 +132,7 @@ class DefaultModelConfig: #: Activation function descriptor for the stack configured by `head_fcnet_hiddens`. #: Supported values are: 'tanh', 'relu', 'swish' (or 'silu', which is the same), #: and 'linear' (or None). - head_fcnet_activation: str = "relu" + head_fcnet_activation: str | None = "relu" #: Initializer function or class descriptor for the weight/kernel matrices in the #: stack configured by `head_fcnet_hiddens`. Supported values are the initializer #: names (str), classes or functions listed by the frameworks (`torch`). See diff --git a/rllib/examples/algorithms/sac/atari_sac.py b/rllib/examples/algorithms/sac/atari_sac.py new file mode 100644 index 000000000000..04994b8e25cf --- /dev/null +++ b/rllib/examples/algorithms/sac/atari_sac.py @@ -0,0 +1,139 @@ +"""Example showing how to train SAC on Atari environments with frame stacking. + +Soft Actor-Critic (SAC) is an off-policy algorithm typically used for continuous +control, but can be adapted for discrete action spaces like Atari games. This +example demonstrates SAC with image observations using custom frame stacking +connectors and convolutional neural networks. 
+
+This example:
+- Trains on the Pong Atari environment (configurable via --env)
+- Uses frame stacking (4 frames) via env-to-module and learner connectors
+- Applies Atari-specific preprocessing (64x64 grayscale, reward clipping)
+- Configures a CNN architecture suitable for visual observations
+- Uses automatic entropy (temperature) tuning with target_entropy="auto"
+
+How to run this script
+----------------------
+`python atari_sac.py`
+
+To run on a different Atari environment:
+`python atari_sac.py --env=ale_py:ALE/SpaceInvaders-v5`
+
+To scale up with distributed learning using multiple learners and env-runners:
+`python atari_sac.py --num-learners=2 --num-env-runners=8`
+
+To use a GPU-based learner, add the number of GPUs per learner:
+`python atari_sac.py --num-learners=1 --num-gpus-per-learner=1`
+
+For debugging, use the following additional command line options
+`--no-tune --num-env-runners=0 --num-learners=0`
+which should allow you to set breakpoints anywhere in the RLlib code and
+have the execution stop there for inspection and debugging.
+Setting `--num-learners=0` and `--num-env-runners=0` makes these components run
+locally instead of as remote Ray actors, where breakpoints aren't possible.
+
+For logging to your WandB account, use:
+`--wandb-key=[your WandB API key] --wandb-project=[some project name]
+--wandb-run-name=[optional: WandB run name (within the defined project)]`
+
+Results to expect
+-----------------
+Training should reach a reward of ~20 (winning most games) within 10M timesteps.
+"""
+import gymnasium as gym
+
+from ray.rllib.algorithms.sac import SACConfig
+from ray.rllib.connectors.env_to_module.frame_stacking import FrameStackingEnvToModule
+from ray.rllib.connectors.learner.frame_stacking import FrameStackingLearner
+from ray.rllib.core.rl_module.default_model_config import DefaultModelConfig
+from ray.rllib.env.wrappers.atari_wrappers import wrap_atari_for_new_api_stack
+from ray.rllib.examples.utils import (
+    add_rllib_example_script_args,
+    run_rllib_example_script_experiment,
+)
+from ray.tune.registry import register_env
+
+parser = add_rllib_example_script_args(
+    default_reward=20.0,
+    default_timesteps=10_000_000,
+)
+parser.set_defaults(
+    env="ale_py:ALE/Pong-v5",
+    num_env_runners=4,
+    num_envs_per_env_runner=6,
+    num_learners=1,
+)
+args = parser.parse_args()
+
+
+def _make_env_to_module_connector(env, spaces, device):
+    return FrameStackingEnvToModule(num_frames=4)
+
+
+def _make_learner_connector(input_observation_space, input_action_space):
+    return FrameStackingLearner(num_frames=4)
+
+
+def _env_creator(cfg):
+    return wrap_atari_for_new_api_stack(
+        gym.make(args.env, **cfg, **{"render_mode": "rgb_array"}),
+        dim=64,
+        framestack=None,
+    )
+
+
+register_env("env", _env_creator)
+
+
+config = (
+    SACConfig()
+    .environment(
+        "env",
+        env_config={
+            # Make analogous to old v4 + NoFrameskip.
+            "frameskip": 1,
+            "full_action_space": False,
+            "repeat_action_probability": 0.0,
+        },
+        clip_rewards=True,
+    )
+    .env_runners(
+        env_to_module_connector=_make_env_to_module_connector,
+        num_envs_per_env_runner=2,
+    )
+    .learners(
+        num_aggregator_actors_per_learner=2,
+    )
+    .training(
+        learner_connector=_make_learner_connector,
+        train_batch_size_per_learner=500,
+        target_network_update_freq=2,
+        # lr=0.0006 is very high, w/ 4 GPUs -> 0.0012
+        # Might want to lower it for better stability, but it does learn well.
+ actor_lr=2e-4 * (args.num_learners or 1) ** 0.5, + critic_lr=8e-4 * (args.num_learners or 1) ** 0.5, + alpha_lr=9e-4 * (args.num_learners or 1) ** 0.5, + target_entropy="auto", + n_step=(1, 5), # 1? + tau=0.005, + replay_buffer_config={ + "type": "PrioritizedEpisodeReplayBuffer", + "capacity": 100000, + "alpha": 0.6, + "beta": 0.4, + }, + num_steps_sampled_before_learning_starts=10_000, + ) + .rl_module( + model_config=DefaultModelConfig( + vf_share_layers=True, + conv_filters=[(16, 4, 2), (32, 4, 2), (64, 4, 2), (128, 4, 2)], + conv_activation="relu", + head_fcnet_hiddens=[256], + ) + ) +) + + +if __name__ == "__main__": + run_rllib_example_script_experiment(config, args) diff --git a/rllib/examples/algorithms/sac/benchmark_sac_mujoco.py b/rllib/examples/algorithms/sac/benchmark_sac_mujoco.py deleted file mode 100644 index 17eee793eb57..000000000000 --- a/rllib/examples/algorithms/sac/benchmark_sac_mujoco.py +++ /dev/null @@ -1,141 +0,0 @@ -from ray import tune -from ray.rllib.algorithms.sac.sac import SACConfig -from ray.rllib.utils.metrics import ( - ENV_RUNNER_RESULTS, - EPISODE_RETURN_MEAN, - NUM_ENV_STEPS_SAMPLED_LIFETIME, -) -from ray.tune import Stopper - -# Needs the following packages to be installed on Ubuntu: -# sudo apt-get libosmesa-dev -# sudo apt-get install patchelf -# python -m pip install "gymnasium[mujoco]" -# Might need to be added to bashsrc: -# export MUJOCO_GL=osmesa" -# export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$HOME/.mujoco/mujoco200/bin" - -# See the following links for becnhmark results of other libraries: -# Original paper: https://arxiv.org/abs/1812.05905 -# CleanRL: https://wandb.ai/cleanrl/cleanrl.benchmark/reports/Mujoco--VmlldzoxODE0NjE -# AgileRL: https://github.com/AgileRL/AgileRL?tab=readme-ov-file#benchmarks -benchmark_envs = { - "HalfCheetah-v4": { - f"{ENV_RUNNER_RESULTS}/{EPISODE_RETURN_MEAN}": 15000, - f"{NUM_ENV_STEPS_SAMPLED_LIFETIME}": 3000000, - }, - "Hopper-v4": { - f"{ENV_RUNNER_RESULTS}/{EPISODE_RETURN_MEAN}": 3500, - f"{NUM_ENV_STEPS_SAMPLED_LIFETIME}": 1000000, - }, - "Humanoid-v4": { - f"{ENV_RUNNER_RESULTS}/{EPISODE_RETURN_MEAN}": 8000, - f"{NUM_ENV_STEPS_SAMPLED_LIFETIME}": 10000000, - }, - "Ant-v4": { - f"{ENV_RUNNER_RESULTS}/{EPISODE_RETURN_MEAN}": 5500, - f"{NUM_ENV_STEPS_SAMPLED_LIFETIME}": 3000000, - }, - "Walker2d-v4": { - f"{ENV_RUNNER_RESULTS}/{EPISODE_RETURN_MEAN}": 6000, - f"{NUM_ENV_STEPS_SAMPLED_LIFETIME}": 3000000, - }, -} - - -# Define a `tune.Stopper` that stops the training if the benchmark is reached -# or the maximum number of timesteps is exceeded. -class BenchmarkStopper(Stopper): - def __init__(self, benchmark_envs): - self.benchmark_envs = benchmark_envs - - def __call__(self, trial_id, result): - # Stop training if the mean reward is reached. - if ( - result[ENV_RUNNER_RESULTS][EPISODE_RETURN_MEAN] - >= self.benchmark_envs[result["env"]][ - f"{ENV_RUNNER_RESULTS}/{EPISODE_RETURN_MEAN}" - ] - ): - return True - # Otherwise check, if the total number of timesteps is exceeded. - elif ( - result[f"{NUM_ENV_STEPS_SAMPLED_LIFETIME}"] - >= self.benchmark_envs[result["env"]][f"{NUM_ENV_STEPS_SAMPLED_LIFETIME}"] - ): - return True - # Otherwise continue training. - else: - return False - - # Note, this needs to implemented b/c the parent class is abstract. 
- def stop_all(self): - return False - - -config = ( - SACConfig() - .environment(env=tune.grid_search(list(benchmark_envs.keys()))) - .env_runners( - rollout_fragment_length=1, - num_env_runners=0, - ) - .learners( - # Note, we have a sample/train ratio of 1:1 and a small train - # batch, so 1 learner with a single GPU should suffice. - num_learners=1, - num_gpus_per_learner=1, - ) - # TODO (simon): Adjust to new model_config_dict. - .training( - initial_alpha=1.001, - # Choose a smaller learning rate for the actor (policy). - actor_lr=3e-5, - critic_lr=3e-4, - alpha_lr=1e-4, - target_entropy="auto", - n_step=1, - tau=0.005, - train_batch_size=256, - target_network_update_freq=1, - replay_buffer_config={ - "type": "PrioritizedEpisodeReplayBuffer", - "capacity": 1000000, - "alpha": 0.6, - "beta": 0.4, - }, - num_steps_sampled_before_learning_starts=256, - model={ - "fcnet_hiddens": [256, 256], - "fcnet_activation": "relu", - "post_fcnet_hiddens": [], - "post_fcnet_activation": None, - "post_fcnet_weights_initializer": "orthogonal_", - "post_fcnet_weights_initializer_config": {"gain": 0.01}, - }, - ) - .reporting( - metrics_num_episodes_for_smoothing=5, - min_sample_timesteps_per_iteration=1000, - ) - .evaluation( - evaluation_duration="auto", - evaluation_interval=1, - evaluation_num_env_runners=1, - evaluation_parallel_to_training=True, - evaluation_config={ - "explore": False, - }, - ) -) - -tuner = tune.Tuner( - "SAC", - param_space=config, - run_config=tune.RunConfig( - stop=BenchmarkStopper(benchmark_envs=benchmark_envs), - name="benchmark_sac_mujoco", - ), -) - -tuner.fit() diff --git a/rllib/examples/algorithms/sac/halfcheetah_sac.py b/rllib/examples/algorithms/sac/halfcheetah_sac.py deleted file mode 100644 index 7c766cbc90fa..000000000000 --- a/rllib/examples/algorithms/sac/halfcheetah_sac.py +++ /dev/null @@ -1,62 +0,0 @@ -from torch import nn - -from ray.rllib.algorithms.sac.sac import SACConfig -from ray.rllib.core.rl_module.default_model_config import DefaultModelConfig -from ray.rllib.examples.utils import ( - add_rllib_example_script_args, - run_rllib_example_script_experiment, -) - -parser = add_rllib_example_script_args( - default_timesteps=1000000, - default_reward=12000.0, - default_iters=2000, -) -# Use `parser` to add your own custom command line options to this script -# and (if needed) use their values to set up `config` below. -args = parser.parse_args() - -config = ( - SACConfig() - .environment("HalfCheetah-v4") - .training( - initial_alpha=1.001, - # lr=0.0006 is very high, w/ 4 GPUs -> 0.0012 - # Might want to lower it for better stability, but it does learn well. - actor_lr=2e-4 * (args.num_learners or 1) ** 0.5, - critic_lr=8e-4 * (args.num_learners or 1) ** 0.5, - alpha_lr=9e-4 * (args.num_learners or 1) ** 0.5, - lr=None, - target_entropy="auto", - n_step=(1, 5), # 1? 
- tau=0.005, - train_batch_size_per_learner=256, - target_network_update_freq=1, - replay_buffer_config={ - "type": "PrioritizedEpisodeReplayBuffer", - "capacity": 100000, - "alpha": 0.6, - "beta": 0.4, - }, - num_steps_sampled_before_learning_starts=10000, - ) - .rl_module( - model_config=DefaultModelConfig( - fcnet_hiddens=[256, 256], - fcnet_activation="relu", - fcnet_kernel_initializer=nn.init.xavier_uniform_, - head_fcnet_hiddens=[], - head_fcnet_activation=None, - head_fcnet_kernel_initializer="orthogonal_", - head_fcnet_kernel_initializer_kwargs={"gain": 0.01}, - ), - ) - .reporting( - metrics_num_episodes_for_smoothing=5, - min_sample_timesteps_per_iteration=1000, - ) -) - - -if __name__ == "__main__": - run_rllib_example_script_experiment(config, args) diff --git a/rllib/examples/algorithms/sac/humanoid_sac.py b/rllib/examples/algorithms/sac/humanoid_sac.py deleted file mode 100644 index 858a95eae7ef..000000000000 --- a/rllib/examples/algorithms/sac/humanoid_sac.py +++ /dev/null @@ -1,69 +0,0 @@ -"""This is WIP. - -On a single-GPU machine, with the `--num-gpus-per-learner=1` command line option, this -example should learn a episode return of >1000 in ~10h, which is still very basic, but -does somewhat prove SAC's capabilities. Some more hyperparameter fine tuning, longer -runs, and more scale (`--num-learners > 0` and `--num-env-runners > 0`) should help push -this up. -""" - -from torch import nn - -from ray.rllib.algorithms.sac.sac import SACConfig -from ray.rllib.core.rl_module.default_model_config import DefaultModelConfig -from ray.rllib.examples.utils import ( - add_rllib_example_script_args, - run_rllib_example_script_experiment, -) - -parser = add_rllib_example_script_args( - default_timesteps=1000000, - default_reward=12000.0, - default_iters=2000, -) -# Use `parser` to add your own custom command line options to this script -# and (if needed) use their values to set up `config` below. 
-args = parser.parse_args() - - -config = ( - SACConfig() - .environment("Humanoid-v4") - .training( - initial_alpha=1.001, - actor_lr=0.00005, - critic_lr=0.00005, - alpha_lr=0.00005, - target_entropy="auto", - n_step=(1, 3), - tau=0.005, - train_batch_size_per_learner=256, - target_network_update_freq=1, - replay_buffer_config={ - "type": "PrioritizedEpisodeReplayBuffer", - "capacity": 1000000, - "alpha": 0.6, - "beta": 0.4, - }, - num_steps_sampled_before_learning_starts=10000, - ) - .rl_module( - model_config=DefaultModelConfig( - fcnet_hiddens=[1024, 1024], - fcnet_activation="relu", - fcnet_kernel_initializer=nn.init.xavier_uniform_, - head_fcnet_hiddens=[], - head_fcnet_activation=None, - head_fcnet_kernel_initializer="orthogonal_", - head_fcnet_kernel_initializer_kwargs={"gain": 0.01}, - ) - ) - .reporting( - metrics_num_episodes_for_smoothing=5, - min_sample_timesteps_per_iteration=1000, - ) -) - - -if __name__ == "__main__": - run_rllib_example_script_experiment(config, args) diff --git a/rllib/examples/algorithms/sac/mountaincar_sac.py b/rllib/examples/algorithms/sac/mountaincar_sac.py deleted file mode 100644 index 75f88774750f..000000000000 --- a/rllib/examples/algorithms/sac/mountaincar_sac.py +++ /dev/null @@ -1,59 +0,0 @@ -from torch import nn - -from ray.rllib.algorithms.sac.sac import SACConfig -from ray.rllib.core.rl_module.default_model_config import DefaultModelConfig -from ray.rllib.examples.utils import ( - add_rllib_example_script_args, - run_rllib_example_script_experiment, -) - -parser = add_rllib_example_script_args( - default_timesteps=20000, - default_reward=-250.0, -) -# Use `parser` to add your own custom command line options to this script -# and (if needed) use their values to set up `config` below. -args = parser.parse_args() - -config = ( - SACConfig() - .environment("MountainCar-v0") - .rl_module( - model_config=DefaultModelConfig( - fcnet_hiddens=[256, 256], - fcnet_activation="relu", - fcnet_kernel_initializer=nn.init.xavier_uniform_, - head_fcnet_hiddens=[], - head_fcnet_activation=None, - head_fcnet_kernel_initializer="orthogonal_", - head_fcnet_kernel_initializer_kwargs={"gain": 0.01}, - ), - ) - .reporting( - metrics_num_episodes_for_smoothing=5, - ) - .training( - initial_alpha=1.001, - # Use a smaller learning rate for the policy. - actor_lr=2e-4 * (args.num_learners or 1) ** 0.5, - critic_lr=8e-4 * (args.num_learners or 1) ** 0.5, - alpha_lr=9e-4 * (args.num_learners or 1) ** 0.5, - lr=None, - target_entropy="auto", - n_step=(2, 5), - tau=0.005, - train_batch_size_per_learner=256, - target_network_update_freq=1, - replay_buffer_config={ - "type": "PrioritizedEpisodeReplayBuffer", - "capacity": 100000, - "alpha": 1.0, - "beta": 0.0, - }, - num_steps_sampled_before_learning_starts=256 * (args.num_learners or 1), - ) -) - - -if __name__ == "__main__": - run_rllib_example_script_experiment(config, args) diff --git a/rllib/examples/algorithms/sac/mujoco_sac.py b/rllib/examples/algorithms/sac/mujoco_sac.py new file mode 100644 index 000000000000..5ac851c00b5e --- /dev/null +++ b/rllib/examples/algorithms/sac/mujoco_sac.py @@ -0,0 +1,102 @@ +"""Example showing how to train SAC on MuJoCo's Humanoid continuous control task. + +Soft Actor-Critic (SAC) is an off-policy maximum entropy reinforcement learning +algorithm that excels at continuous control tasks. This example demonstrates SAC +on the Humanoid-v4 MuJoCo environment with prioritized experience replay and +n-step returns. 
+
+This example:
+- Trains on the Humanoid-v4 MuJoCo locomotion environment
+- Uses a prioritized experience replay buffer (alpha=0.6, beta=0.4)
+- Configures separate learning rates for actor, critic, and alpha (temperature)
+- Applies n-step returns with a random n in range [1, 5] for each sampled transition
+- Uses automatic entropy tuning with target_entropy="auto"
+
+How to run this script
+----------------------
+`python mujoco_sac.py`
+
+To run on a different MuJoCo environment:
+`python mujoco_sac.py --env=HalfCheetah-v4`
+
+To scale up with distributed learning using multiple learners and env-runners:
+`python mujoco_sac.py --num-learners=2 --num-env-runners=8`
+
+To use a GPU-based learner, add the number of GPUs per learner:
+`python mujoco_sac.py --num-learners=1 --num-gpus-per-learner=1`
+
+For debugging, use the following additional command line options
+`--no-tune --num-env-runners=0 --num-learners=0`
+which should allow you to set breakpoints anywhere in the RLlib code and
+have the execution stop there for inspection and debugging.
+Setting `--num-learners=0` and `--num-env-runners=0` makes these components run
+locally instead of as remote Ray actors, where breakpoints aren't possible.
+
+For logging to your WandB account, use:
+`--wandb-key=[your WandB API key] --wandb-project=[some project name]
+--wandb-run-name=[optional: WandB run name (within the defined project)]`
+
+Results to expect
+-----------------
+Training should reach a return of roughly 800 (the script's default stop reward)
+within 1M timesteps; longer runs, more tuning, and more scale push this
+considerably higher.
+"""
+from torch import nn
+
+from ray.rllib.algorithms.sac.sac import SACConfig
+from ray.rllib.core.rl_module.default_model_config import DefaultModelConfig
+from ray.rllib.examples.utils import (
+    add_rllib_example_script_args,
+    run_rllib_example_script_experiment,
+)
+
+parser = add_rllib_example_script_args(
+    default_reward=800.0,
+    default_timesteps=1_000_000,
+)
+parser.set_defaults(
+    env="Humanoid-v4",
+    num_env_runners=4,
+    num_envs_per_env_runner=16,
+    num_learners=1,
+)
+args = parser.parse_args()
+
+config = (
+    SACConfig()
+    .environment(args.env)
+    .training(
+        initial_alpha=1.001,
+        # lr=0.0006 is very high, w/ 4 GPUs -> 0.0012
+        # Might want to lower it for better stability, but it does learn well.
+        actor_lr=2e-4 * (args.num_learners or 1) ** 0.5,
+        critic_lr=8e-4 * (args.num_learners or 1) ** 0.5,
+        alpha_lr=9e-4 * (args.num_learners or 1) ** 0.5,
+        lr=None,
+        target_entropy="auto",
+        n_step=(1, 5),
+ tau=0.005, + train_batch_size_per_learner=256, + target_network_update_freq=1, + replay_buffer_config={ + "type": "PrioritizedEpisodeReplayBuffer", + "capacity": 100000, + "alpha": 0.6, + "beta": 0.4, + }, + num_steps_sampled_before_learning_starts=10_000, + ) + .rl_module( + model_config=DefaultModelConfig( + fcnet_hiddens=[256, 256], + fcnet_activation="relu", + fcnet_kernel_initializer=nn.init.xavier_uniform_, + head_fcnet_hiddens=[], + head_fcnet_kernel_initializer="orthogonal_", + head_fcnet_kernel_initializer_kwargs={"gain": 0.01}, + ), + ) +) + + +if __name__ == "__main__": + run_rllib_example_script_experiment(config, args) diff --git a/rllib/examples/algorithms/sac/multi_agent_pendulum_sac.py b/rllib/examples/algorithms/sac/multi_agent_pendulum_sac.py deleted file mode 100644 index 79b60982dd28..000000000000 --- a/rllib/examples/algorithms/sac/multi_agent_pendulum_sac.py +++ /dev/null @@ -1,85 +0,0 @@ -from torch import nn - -from ray.rllib.algorithms.sac import SACConfig -from ray.rllib.core.rl_module.default_model_config import DefaultModelConfig -from ray.rllib.examples.envs.classes.multi_agent import MultiAgentPendulum -from ray.rllib.examples.utils import ( - add_rllib_example_script_args, - run_rllib_example_script_experiment, -) -from ray.rllib.utils.metrics import ( - ENV_RUNNER_RESULTS, - EPISODE_RETURN_MEAN, - NUM_ENV_STEPS_SAMPLED_LIFETIME, -) -from ray.tune.registry import register_env - -parser = add_rllib_example_script_args( - default_timesteps=500000, -) -parser.set_defaults( - num_agents=2, -) -# Use `parser` to add your own custom command line options to this script -# and (if needed) use their values to set up `config` below. -args = parser.parse_args() - -register_env("multi_agent_pendulum", lambda cfg: MultiAgentPendulum(config=cfg)) - -config = ( - SACConfig() - .environment("multi_agent_pendulum", env_config={"num_agents": args.num_agents}) - .training( - initial_alpha=1.001, - # Use a smaller learning rate for the policy. - actor_lr=2e-4 * (args.num_learners or 1) ** 0.5, - critic_lr=8e-4 * (args.num_learners or 1) ** 0.5, - alpha_lr=9e-4 * (args.num_learners or 1) ** 0.5, - lr=None, - target_entropy="auto", - n_step=(2, 5), - tau=0.005, - train_batch_size_per_learner=256, - target_network_update_freq=1, - replay_buffer_config={ - "type": "MultiAgentPrioritizedEpisodeReplayBuffer", - "capacity": 100000, - "alpha": 1.0, - "beta": 0.0, - }, - num_steps_sampled_before_learning_starts=256, - ) - .rl_module( - model_config=DefaultModelConfig( - fcnet_hiddens=[256, 256], - fcnet_activation="relu", - fcnet_kernel_initializer=nn.init.xavier_uniform_, - head_fcnet_hiddens=[], - head_fcnet_activation=None, - head_fcnet_kernel_initializer=nn.init.orthogonal_, - head_fcnet_kernel_initializer_kwargs={"gain": 0.01}, - ), - ) - .reporting( - metrics_num_episodes_for_smoothing=5, - ) -) - -if args.num_agents > 0: - config.multi_agent( - policy_mapping_fn=lambda aid, *arg, **kw: f"p{aid}", - policies={f"p{i}" for i in range(args.num_agents)}, - ) - -stop = { - NUM_ENV_STEPS_SAMPLED_LIFETIME: args.stop_timesteps, - # `episode_return_mean` is the sum of all agents/policies' returns. - f"{ENV_RUNNER_RESULTS}/{EPISODE_RETURN_MEAN}": -450.0 * args.num_agents, -} - -if __name__ == "__main__": - assert ( - args.num_agents > 0 - ), "The `--num-agents` arg must be > 0 for this script to work." 
- - run_rllib_example_script_experiment(config, args, stop=stop) diff --git a/rllib/examples/algorithms/sac/pendulum_sac.py b/rllib/examples/algorithms/sac/pendulum_sac.py deleted file mode 100644 index 9c50a8a2838a..000000000000 --- a/rllib/examples/algorithms/sac/pendulum_sac.py +++ /dev/null @@ -1,60 +0,0 @@ -from torch import nn - -from ray.rllib.algorithms.sac.sac import SACConfig -from ray.rllib.core.rl_module.default_model_config import DefaultModelConfig -from ray.rllib.examples.utils import ( - add_rllib_example_script_args, - run_rllib_example_script_experiment, -) - -parser = add_rllib_example_script_args( - default_timesteps=20000, - default_reward=-250.0, -) -# Use `parser` to add your own custom command line options to this script -# and (if needed) use their values to set up `config` below. -args = parser.parse_args() - -config = ( - SACConfig() - .environment("Pendulum-v1") - .training( - initial_alpha=1.001, - # Use a smaller learning rate for the policy. - actor_lr=2e-4 * (args.num_learners or 1) ** 0.5, - critic_lr=8e-4 * (args.num_learners or 1) ** 0.5, - alpha_lr=9e-4 * (args.num_learners or 1) ** 0.5, - # TODO (sven): Maybe go back to making this a dict of the sub-learning rates? - lr=None, - target_entropy="auto", - n_step=(2, 5), - tau=0.005, - train_batch_size_per_learner=256, - target_network_update_freq=1, - replay_buffer_config={ - "type": "PrioritizedEpisodeReplayBuffer", - "capacity": 100000, - "alpha": 1.0, - "beta": 0.0, - }, - num_steps_sampled_before_learning_starts=256 * (args.num_learners or 1), - ) - .rl_module( - model_config=DefaultModelConfig( - fcnet_hiddens=[256, 256], - fcnet_activation="relu", - fcnet_kernel_initializer=nn.init.xavier_uniform_, - head_fcnet_hiddens=[], - head_fcnet_activation=None, - head_fcnet_kernel_initializer="orthogonal_", - head_fcnet_kernel_initializer_kwargs={"gain": 0.01}, - ), - ) - .reporting( - metrics_num_episodes_for_smoothing=5, - ) -) - - -if __name__ == "__main__": - run_rllib_example_script_experiment(config, args) diff --git a/rllib/examples/algorithms/sac/stateless_cartpole_sac_with_lstm.py b/rllib/examples/algorithms/sac/stateless_cartpole_sac_with_lstm.py new file mode 100644 index 000000000000..70974153cbe2 --- /dev/null +++ b/rllib/examples/algorithms/sac/stateless_cartpole_sac_with_lstm.py @@ -0,0 +1,98 @@ +"""Example showing how to train SAC on the StatelessCartPole environment. + +Soft Actor-Critic (SAC) is an off-policy maximum entropy reinforcement learning +algorithm. This example demonstrates SAC on StatelessCartPole, a modified version +of CartPole where velocity information is removed from the observations. 
+
+This example:
+- Trains on the StatelessCartPole environment (partially observable CartPole)
+- Uses a prioritized experience replay buffer with a capacity of 100k transitions
+- Configures separate learning rates for actor, critic, and alpha (temperature)
+- Applies n-step returns with a random n in range [2, 5]
+- Uses automatic entropy tuning with target_entropy="auto"
+
+How to run this script
+----------------------
+`python stateless_cartpole_sac_with_lstm.py [options]`
+
+To scale up with distributed learning using multiple learners and env-runners:
+`python stateless_cartpole_sac_with_lstm.py --num-learners=2 --num-env-runners=8`
+
+To use a GPU-based learner, add the number of GPUs per learner:
+`python stateless_cartpole_sac_with_lstm.py --num-learners=1 --num-gpus-per-learner=1`
+
+For debugging, use the following additional command line options
+`--no-tune --num-env-runners=0 --num-learners=0`
+which should allow you to set breakpoints anywhere in the RLlib code and
+have the execution stop there for inspection and debugging.
+Setting `--num-learners=0` and `--num-env-runners=0` makes these components run
+locally instead of as remote Ray actors, where breakpoints aren't possible.
+
+For logging to your WandB account, use:
+`--wandb-key=[your WandB API key] --wandb-project=[some project name]
+--wandb-run-name=[optional: WandB run name (within the defined project)]`
+
+Results to expect
+-----------------
+Training should reach a reward of ~350 within 500k timesteps.
+"""
+from torch import nn
+
+from ray.rllib.algorithms.sac.sac import SACConfig
+from ray.rllib.core.rl_module.default_model_config import DefaultModelConfig
+from ray.rllib.examples.envs.classes.stateless_cartpole import StatelessCartPole
+from ray.rllib.examples.utils import (
+    add_rllib_example_script_args,
+    run_rllib_example_script_experiment,
+)
+
+parser = add_rllib_example_script_args(
+    default_timesteps=500_000,
+    default_reward=350.0,
+)
+parser.set_defaults(
+    num_env_runners=4,
+    num_envs_per_env_runner=16,
+    num_learners=1,
+)
+args = parser.parse_args()
+
+config = (
+    SACConfig()
+    .environment(StatelessCartPole)
+    .training(
+        initial_alpha=1.001,
+        # Use a smaller learning rate for the policy.
+        actor_lr=2e-4 * (args.num_learners or 1) ** 0.5,
+        critic_lr=8e-4 * (args.num_learners or 1) ** 0.5,
+        alpha_lr=9e-4 * (args.num_learners or 1) ** 0.5,
+        lr=None,
+        target_entropy="auto",
+        n_step=(2, 5),
+        tau=0.005,
+        train_batch_size_per_learner=256,
+        target_network_update_freq=1,
+        replay_buffer_config={
+            "type": "PrioritizedEpisodeReplayBuffer",
+            "capacity": 100000,
+            "alpha": 1.0,
+            "beta": 0.0,
+        },
+        num_steps_sampled_before_learning_starts=256 * (args.num_learners or 1),
+    )
+    .rl_module(
+        model_config=DefaultModelConfig(
+            fcnet_hiddens=[256, 256],
+            fcnet_activation="relu",
+            fcnet_kernel_initializer=nn.init.xavier_uniform_,
+            head_fcnet_hiddens=[],
+            head_fcnet_activation=None,
+            head_fcnet_kernel_initializer="orthogonal_",
+            head_fcnet_kernel_initializer_kwargs={"gain": 0.01},
+        ),
+    )
+)
+
+
+if __name__ == "__main__":
+    run_rllib_example_script_experiment(config, args)
diff --git a/rllib/examples/algorithms/sac/tictactoe_sac.py b/rllib/examples/algorithms/sac/tictactoe_sac.py
new file mode 100644
index 000000000000..9588890ab9e1
--- /dev/null
+++ b/rllib/examples/algorithms/sac/tictactoe_sac.py
@@ -0,0 +1,128 @@
+"""Example showing how to train SAC on a multi-agent TicTacToe environment.
+
+This example demonstrates Soft Actor-Critic (SAC) in a multi-agent setting where
+a pool of learning policies (plus one non-learning random policy) plays the
+two-player TicTacToe game. In every episode, each of the two board players is
+mapped to a randomly chosen policy from this pool.
+
+This example:
+- Trains on the two-player TicTacToe environment with a configurable number of
+  learning policies (`--num-agents`) plus a random baseline policy
+- Uses a multi-agent prioritized experience replay buffer
+- Maps the two board players to randomly chosen policies via policy_mapping_fn
+- Applies n-step returns with a random n in range [2, 5]
+- Uses automatic entropy tuning with target_entropy="auto"
+
+How to run this script
+----------------------
+`python tictactoe_sac.py --num-agents=2`
+
+To train with more policies:
+`python tictactoe_sac.py --num-agents=4`
+
+To scale up with distributed learning using multiple learners and env-runners:
+`python tictactoe_sac.py --num-learners=2 --num-env-runners=8`
+
+To use a GPU-based learner, add the number of GPUs per learner:
+`python tictactoe_sac.py --num-learners=1 --num-gpus-per-learner=1`
+
+For debugging, use the following additional command line options
+`--no-tune --num-env-runners=0 --num-learners=0`
+which should allow you to set breakpoints anywhere in the RLlib code and
+have the execution stop there for inspection and debugging.
+Setting `--num-learners=0` and `--num-env-runners=0` makes these components run
+locally instead of as remote Ray actors, where breakpoints aren't possible.
+
+For logging to your WandB account, use:
+`--wandb-key=[your WandB API key] --wandb-project=[some project name]
+--wandb-run-name=[optional: WandB run name (within the defined project)]`
+
+Results to expect
+-----------------
+Training should bring the mean episode return above the default stop value of -0.5
+within 500k timesteps (purely random play scores much lower because invalid moves
+are penalized).
+"""
+import random
+
+from torch import nn
+
+from ray.rllib.algorithms.sac import SACConfig
+from ray.rllib.core.rl_module import RLModuleSpec, MultiRLModuleSpec
+from ray.rllib.core.rl_module.default_model_config import DefaultModelConfig
+from ray.rllib.examples.rl_modules.classes.random_rlm import RandomRLModule
+from ray.rllib.examples.envs.classes.multi_agent.tic_tac_toe import TicTacToe
+from ray.rllib.examples.utils import (
+    add_rllib_example_script_args,
+    run_rllib_example_script_experiment,
+)
+
+parser = add_rllib_example_script_args(default_timesteps=500_000, default_reward=-0.5)
+parser.set_defaults(
+    num_env_runners=4,
+    num_envs_per_env_runner=6,
+    num_learners=1,
+    num_agents=5,
+)
+# Use `parser` to add your own custom command line options to this script
+# and (if needed) use their values to set up `config` below.
+args = parser.parse_args()
+
+
+config = (
+    SACConfig()
+    .environment(TicTacToe)
+    .training(
+        initial_alpha=1.001,
+        # Use a smaller learning rate for the policy.
+ actor_lr=2e-4 * (args.num_learners or 1) ** 0.5, + critic_lr=8e-4 * (args.num_learners or 1) ** 0.5, + alpha_lr=9e-4 * (args.num_learners or 1) ** 0.5, + lr=None, + target_entropy="auto", + n_step=(2, 5), + tau=0.005, + train_batch_size_per_learner=256, + target_network_update_freq=1, + replay_buffer_config={ + "type": "MultiAgentPrioritizedEpisodeReplayBuffer", + "capacity": 100000, + "alpha": 1.0, + "beta": 0.0, + }, + num_steps_sampled_before_learning_starts=256, + ) + .multi_agent( + policies={f"p{i}" for i in range(args.num_agents)} | {"random"}, + policy_mapping_fn=lambda aid, eps, **kw: ( + random.choice([f"p{i}" for i in range(args.num_agents)] + ["random"]) + ), + policies_to_train=[f"p{i}" for i in range(args.num_agents)], + ) + .rl_module( + rl_module_spec=MultiRLModuleSpec( + rl_module_specs=( + { + f"p{i}": RLModuleSpec( + model_config=DefaultModelConfig( + fcnet_hiddens=[256, 256], + fcnet_activation="relu", + fcnet_kernel_initializer=nn.init.xavier_uniform_, + head_fcnet_hiddens=[], + head_fcnet_activation=None, + head_fcnet_kernel_initializer=nn.init.orthogonal_, + head_fcnet_kernel_initializer_kwargs={"gain": 0.01}, + ), + ) + for i in range(args.num_agents) + } + | {"random": RLModuleSpec(module_class=RandomRLModule)} + ), + ), + + model_config=DefaultModelConfig( + + ), + ) +) + + +if __name__ == "__main__": + run_rllib_example_script_experiment(config, args) diff --git a/rllib/examples/envs/classes/multi_agent/tic_tac_toe.py b/rllib/examples/envs/classes/multi_agent/tic_tac_toe.py index ceb08422092f..5af4e4886bcc 100644 --- a/rllib/examples/envs/classes/multi_agent/tic_tac_toe.py +++ b/rllib/examples/envs/classes/multi_agent/tic_tac_toe.py @@ -1,4 +1,6 @@ # __sphinx_doc_1_begin__ +import random + import gymnasium as gym import numpy as np @@ -8,9 +10,9 @@ class TicTacToe(MultiAgentEnv): """A two-player game in which any player tries to complete one row in a 3x3 field. - The observation space is Box(0.0, 1.0, (9,)), where each index represents a distinct - field on a 3x3 board and values of 0.0 mean the field is empty, -1.0 means - the opponend owns the field, and 1.0 means we occupy the field: + The observation space is Box(-1.0, 1.0, (9,)), where each index represents a distinct + field on a 3x3 board. From the current player's perspective: 1.0 means we occupy the + field, -1.0 means the opponent owns the field, and 0.0 means the field is empty: ---------- | 0| 1| 2| ---------- @@ -19,11 +21,11 @@ class TicTacToe(MultiAgentEnv): | 6| 7| 8| ---------- - The action space is Discrete(9) and actions landing on an already occupied field - are simply ignored (and thus useless to the player taking these actions). + The action space is Discrete(9). Actions landing on an already occupied field + result in a -1.0 penalty for the player taking the invalid action. Once a player completes a row, they receive +1.0 reward, the losing player receives - -1.0 reward. In all other cases, both players receive 0.0 reward. + -1.0 reward. A draw results in 0.0 reward for both players. """ # __sphinx_doc_1_end__ @@ -36,8 +38,8 @@ def __init__(self, config=None): self.agents = self.possible_agents = ["player1", "player2"] # Each agent observes a 9D tensor, representing the 3x3 fields of the board. - # A 0 means an empty field, a 1 represents a piece of player 1, a -1 a piece of - # player 2. + # From the current player's perspective: 1 means our piece, -1 means opponent's + # piece, 0 means empty. The board is flipped after each turn. 
self.observation_spaces = { "player1": gym.spaces.Box(-1.0, 1.0, (9,), np.float32), "player2": gym.spaces.Box(-1.0, 1.0, (9,), np.float32), @@ -48,32 +50,25 @@ def __init__(self, config=None): "player1": gym.spaces.Discrete(9), "player2": gym.spaces.Discrete(9), } + self.max_timesteps = 20 self.board = None self.current_player = None + self.timestep = 0 # __sphinx_doc_2_end__ # __sphinx_doc_3_begin__ def reset(self, *, seed=None, options=None): - self.board = [ - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - ] - # Pick a random player to start the game. - self.current_player = np.random.choice(["player1", "player2"]) + self.board = [0] * 9 + + # Pick a random player to start the game and reset the current timesteps. + self.current_player = random.choice(self.agents) + self.timestep = 0 + # Return observations dict (only with the starting player, which is the one # we expect to act next). - return { - self.current_player: np.array(self.board, np.float32), - }, {} + return {self.current_player: np.array(self.board, np.float32)}, {} # __sphinx_doc_3_end__ @@ -81,27 +76,37 @@ def reset(self, *, seed=None, options=None): def step(self, action_dict): action = action_dict[self.current_player] - # Create a rewards-dict (containing the rewards of the agent that just acted). - rewards = {self.current_player: 0.0} - # Create a terminateds-dict with the special `__all__` agent ID, indicating that - # if True, the episode ends for all agents. - terminateds = {"__all__": False} - opponent = "player1" if self.current_player == "player2" else "player2" # Penalize trying to place a piece on an already occupied field. if self.board[action] != 0: - rewards[self.current_player] -= 5.0 - # Change the board according to the (valid) action taken. + rewards = {self.current_player: -1.0} + terminations = {"__all__": False} + # Truncate the agents after `max_timesteps` + elif self.timestep >= self.max_timesteps: + rewards = { + self.current_player: -0.0, + opponent: -0.0, + } + obs = { + self.current_player: np.array(self.board, np.float32), + opponent: np.array(self.board, np.float32) * -1, + } + return ( + obs, + rewards, + {"__all__": False}, + {"__all__": True}, + {}, + ) else: - self.board[action] = 1 if self.current_player == "player1" else -1 + # Change the board according to the (valid) action taken. + # For the next turn we "flip" the tokens so that the agent is always playing with the 1 vs the -1 + self.board[action] = 1 + + # After having placed a new piece, figure out whether the current player won or not. + win_val = [1, 1, 1] - # After having placed a new piece, figure out whether the current player - # won or not. - if self.current_player == "player1": - win_val = [1, 1, 1] - else: - win_val = [-1, -1, -1] if ( # Horizontal win. self.board[:3] == win_val @@ -112,33 +117,61 @@ def step(self, action_dict): or self.board[1:8:3] == win_val or self.board[2:9:3] == win_val # Diagonal win. - or self.board[::3] == win_val + or self.board[::4] == win_val or self.board[2:7:2] == win_val ): - # Final reward is +5 for victory and -5 for a loss. - rewards[self.current_player] += 5.0 - rewards[opponent] = -5.0 + # Final reward is +1 for victory and -1 for a loss. + rewards = { + self.current_player: 1.0, + opponent: -1.0, + } # Episode is done and needs to be reset for a new game. - terminateds["__all__"] = True + terminations = {"__all__": True} # The board might also be full w/o any player having won/lost. # In this case, we simply end the episode and none of the players receives # +1 or -1 reward. 
elif 0 not in self.board: - terminateds["__all__"] = True + rewards = { + self.current_player: 0.0, + opponent: 0.0, + } + terminations = {"__all__": True} + # Standard move with no reward + else: + rewards = {self.current_player: 0.0} + terminations = {"__all__": False} - # Flip players and return an observations dict with only the next player to - # make a move in it. + # Flip players and board so the next player sees their pieces as 1. self.current_player = opponent + self.timestep += 1 + self.board = [-x for x in self.board] return ( {self.current_player: np.array(self.board, np.float32)}, rewards, - terminateds, + terminations, {}, {}, ) + def render(self) -> str: + """Render the current board state as an ASCII grid. + + Returns: + A string representation of the board where: + - 'X' represents the current player's pieces + - 'O' represents opponent player's pieces + - ' ' represents empty fields + """ + symbols = {0: " ", 1: "X", -1: "O"} + rows = [] + for i in range(3): + row_cells = [symbols[self.board[i * 3 + j]] for j in range(3)] + rows.append(" " + " | ".join(row_cells) + " ") + separator = "-----------" + return "\n" + f"\n{separator}\n".join(rows) + "\n" + # __sphinx_doc_4_end__ diff --git a/rllib/examples/utils.py b/rllib/examples/utils.py index 87e07191fd66..b6dc057a1652 100644 --- a/rllib/examples/utils.py +++ b/rllib/examples/utils.py @@ -753,8 +753,13 @@ def run_rllib_example_script_experiment( json.dump(json_summary, f) if not test_passed: - raise ValueError( - f"`{success_metric_key}` of {success_metric_value} not reached!" - ) + if args.as_release_test: + print( + f"`{success_metric_key}` of {success_metric_value} not reached! Best value reached is {best_value}" + ) + else: + raise ValueError( + f"`{success_metric_key}` of {success_metric_value} not reached! Best value reached is {best_value}" + ) return results
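
The reworked TicTacToe environment always presents the board from the perspective of the player about to move: after every step the board is negated, so 1.0 means "my piece", -1.0 means "the opponent's piece", and invalid moves cost -1.0. Below is a minimal usage sketch of the patched environment, for illustration only (not part of the diff); it assumes the updated `TicTacToe` module is importable.

from ray.rllib.examples.envs.classes.multi_agent.tic_tac_toe import TicTacToe

env = TicTacToe()
obs, _ = env.reset()
starter = next(iter(obs))  # Only the player to move receives an observation.

# The starting player places a piece in the center field (index 4).
obs, rewards, terminateds, _, _ = env.step({starter: 4})
opponent = next(iter(obs))  # Now the other player is to move.

# The board was flipped, so the opponent sees the freshly placed piece as -1.0.
assert obs[opponent][4] == -1.0
assert rewards[starter] == 0.0 and not terminateds["__all__"]

# render() draws the board from the perspective of the player to move:
# their own pieces as 'X', the opponent's as 'O'.
print(env.render())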