diff --git a/rllib/connectors/common/numpy_to_tensor.py b/rllib/connectors/common/numpy_to_tensor.py
index 2e4c954dd6a7..5980a4535ba9 100644
--- a/rllib/connectors/common/numpy_to_tensor.py
+++ b/rllib/connectors/common/numpy_to_tensor.py
@@ -111,7 +111,8 @@ def __call__(
                 )
             else:
                 raise ValueError(
-                    "`NumpyToTensor`does NOT support frameworks other than torch!"
+                    "`NumpyToTensor` does NOT support frameworks other than torch! "
+                    f"Your current framework is {rl_module.framework}"
                 )
             if infos is not None:
                 module_data[Columns.INFOS] = infos
diff --git a/rllib/env/multi_agent_episode.py b/rllib/env/multi_agent_episode.py
index 35327fca5b92..e42e57fd9a8d 100644
--- a/rllib/env/multi_agent_episode.py
+++ b/rllib/env/multi_agent_episode.py
@@ -798,8 +798,10 @@ def concat_episode(self, other: "MultiAgentEpisode") -> None:
         """Adds the given `other` MultiAgentEpisode to the right side of `self`.

         In order for this to work, both chunks (`self` and `other`) must fit
-        together. This is checked by the IDs (must be identical), the time step counters
-        (`self.env_t` must be the same as `episode_chunk.env_t_started`), as well as the
+        together, meaning they were split via `cut()`. Note that for sequential
+        multi-agent environments, splitting via slicing may cause problems due to
+        hanging observations/actions. This is checked by the IDs (must be identical),
+        the time step counters (`self.env_t` must equal `other.env_t_started`), and the
         observations/infos of the individual agents at the concatenation boundaries.
         Also, `self.is_done` must not be True, meaning `self.is_terminated` and
         `self.is_truncated` are both False.
@@ -842,23 +844,6 @@ def concat_episode(self, other: "MultiAgentEpisode") -> None:
             # If the agent has data in both chunks, concatenate on the single-agent
             # level, thereby making sure the hanging values (begin and end) match.
             elif agent_id in other.agent_episodes:
-                # If `other` has hanging (end) values -> Add these to `self`'s agent
-                # SingleAgentEpisode (as a new timestep) and only then concatenate.
-                # Otherwise, the concatentaion would fail b/c of missing data.
-                if agent_id in self._hanging_actions_end:
-                    assert agent_id in self._hanging_extra_model_outputs_end
-                    sa_episode.add_env_step(
-                        observation=other.agent_episodes[agent_id].get_observations(0),
-                        infos=other.agent_episodes[agent_id].get_infos(0),
-                        action=self._hanging_actions_end[agent_id],
-                        reward=(
-                            self._hanging_rewards_end[agent_id]
-                            + other._hanging_rewards_begin[agent_id]
-                        ),
-                        extra_model_outputs=(
-                            self._hanging_extra_model_outputs_end[agent_id]
-                        ),
-                    )
                 sa_episode.concat_episode(other.agent_episodes[agent_id])
                 # Override `self`'s hanging (end) values with `other`'s hanging (end).
                 if agent_id in other._hanging_actions_end:
diff --git a/rllib/env/single_agent_episode.py b/rllib/env/single_agent_episode.py
index dffe8affa04d..84945f1c1c24 100644
--- a/rllib/env/single_agent_episode.py
+++ b/rllib/env/single_agent_episode.py
@@ -598,7 +598,7 @@ def concat_episode(self, other: "SingleAgentEpisode") -> None:

         In order for this to work, both chunks (`self` and `other`) must fit
         together. This is checked by the IDs (must be identical), the time step counters
-        (`self.env_t` must be the same as `episode_chunk.env_t_started`), as well as the
+        (`self.env_t` must be the same as `other.env_t_started`), as well as the
         observations/infos at the concatenation boundaries.
         Also, `self.is_done` must not be True, meaning `self.is_terminated` and
         `self.is_truncated` are both False.
@@ -615,7 +615,7 @@ def concat_episode(self, other: "SingleAgentEpisode") -> None:
         # able to concatenate.
         assert not self.is_done
         # Make sure the timesteps match.
-        assert self.t == other.t_started
+        assert self.t == other.t_started, f"{self.t=}, {other.t_started=}"

         # Validate `other`.
         other.validate()
diff --git a/rllib/env/tests/test_multi_agent_episode.py b/rllib/env/tests/test_multi_agent_episode.py
index 4a2fc83b6dd6..3a132a87ffa4 100644
--- a/rllib/env/tests/test_multi_agent_episode.py
+++ b/rllib/env/tests/test_multi_agent_episode.py
@@ -1,14 +1,20 @@
 import unittest
-from typing import Optional, Tuple
+from typing import Any, Callable, Dict, Optional, Tuple

 import gymnasium as gym
 import numpy as np

 import ray
+from ray.rllib.algorithms.ppo.ppo import PPOConfig
+from ray.rllib.core.columns import Columns
+from ray.rllib.core.rl_module.rl_module import RLModule, RLModuleSpec
 from ray.rllib.env.multi_agent_env import MultiAgentEnv
+from ray.rllib.env.multi_agent_env_runner import MultiAgentEnvRunner
 from ray.rllib.env.multi_agent_episode import MultiAgentEpisode
+from ray.rllib.utils.annotations import override
 from ray.rllib.utils.test_utils import check
 from ray.rllib.utils.typing import MultiAgentDict
+from ray.tune import register_env


 class MultiAgentTestEnv(MultiAgentEnv):
@@ -125,6 +131,128 @@ def step(
         return obs, reward, is_terminated, is_truncated, info


+class EchoRLModule(RLModule):
+    """An RLModule that returns the observation as the action (for testing)."""
+
+    framework = "torch"
+
+    @override(RLModule)
+    def _forward(self, batch: Dict[str, Any], **kwargs) -> Dict[str, Any]:
+        """Return the observation as the action."""
+        obs = batch[Columns.OBS]
+        # For Discrete observation spaces, obs is already an integer/array of integers.
+        return {Columns.ACTIONS: obs}
+
+    @override(RLModule)
+    def _forward_inference(self, batch: Dict[str, Any], **kwargs) -> Dict[str, Any]:
+        return self._forward(batch, **kwargs)
+
+    @override(RLModule)
+    def _forward_exploration(self, batch: Dict[str, Any], **kwargs) -> Dict[str, Any]:
+        return self._forward(batch, **kwargs)
+
+    @override(RLModule)
+    def _forward_train(self, batch: Dict[str, Any], **kwargs) -> Dict[str, Any]:
+        raise NotImplementedError("EchoRLModule is not trainable!")
+
+
+class SequentialMultiAgentEnv(MultiAgentEnv):
+    def __init__(
+        self, agent_fns: dict[str, Callable[[int], bool]], max_episode_length: int = 100
+    ):
+        super().__init__()
+
+        self.agents = list(agent_fns.keys())
+        self.possible_agents = list(agent_fns.keys())
+        self.agent_fns = agent_fns
+
+        self.observation_space = gym.spaces.Dict(
+            {
+                agent: gym.spaces.Discrete(max_episode_length)
+                for agent in self.possible_agents
+            }
+        )
+        self.action_space = gym.spaces.Dict(
+            {
+                agent: gym.spaces.Discrete(max_episode_length)
+                for agent in self.possible_agents
+            }
+        )
+
+        self.agent_timestep = {}
+        self.env_timestep = 0
+        self.max_episode_length = max_episode_length
+
+    def reset(
+        self,
+        *,
+        seed: Optional[int] = None,
+        options: Optional[dict] = None,
+    ) -> Tuple[MultiAgentDict, MultiAgentDict]:
+        self.env_timestep = 0
+        self.agent_timestep = {agent: 0 for agent in self.possible_agents}
+
+        return self.get_obs(), {"env_timestep": self.env_timestep}
+
+    def step(
+        self, action_dict: MultiAgentDict
+    ) -> Tuple[
+        MultiAgentDict, MultiAgentDict, MultiAgentDict, MultiAgentDict, MultiAgentDict
+    ]:
+        if self.env_timestep >= self.max_episode_length:
+            obs = self.agent_timestep
+        else:
+            obs = self.get_obs()
+
+        # Replace 1 with self.env_timestep when debugging to see which env timestep an observation is from.
+        rewards = {agent: 1 for agent in obs.keys()}
+        terminated = {"__all__": self.env_timestep == self.max_episode_length}
+        truncated = {}
+        info = {agent: {"env_timestep": self.env_timestep} for agent in obs.keys()}
+
+        self.env_timestep += 1
+
+        return obs, rewards, terminated, truncated, info
+
+    def get_obs(self) -> dict[str, int]:
+        obs = {}
+        for agent, fn in self.agent_fns.items():
+            if fn(self.env_timestep):
+                obs[agent] = self.agent_timestep[agent]
+                self.agent_timestep[agent] += 1
+
+        assert obs, f"{obs=}, {self.env_timestep=}, {self.agent_timestep=}"
+        return obs
+
+
+def create_env(config):
+    return SequentialMultiAgentEnv(
+        agent_fns={
+            "p_true": lambda x: True,
+            "p_mod_2": lambda x: x % 2 == 0,
+            "p_mod_5": lambda x: x % 5 == 0,
+            "p_within": lambda x: x in [0, 10, 15],
+        },
+        max_episode_length=20,
+    )
+
+
+register_env("env", create_env)
+
+
+config = (
+    PPOConfig()
+    .environment("env")
+    .env_runners(num_envs_per_env_runner=2, num_env_runners=0)
+    .rl_module(rl_module_spec=RLModuleSpec(module_class=EchoRLModule))
+    .multi_agent(
+        policies={"p0"},
+        policy_mapping_fn=lambda aid, eps, **kw: "p0",
+        policies_to_train=[],
+    )
+)
+
+
 # TODO (simon): Test `get_state()` and `from_state()`.
 class TestMultiAgentEpisode(unittest.TestCase):
     @classmethod
@@ -3135,125 +3263,49 @@ def test_slice(self):
         check((a0.is_done, a1.is_done), (False, False))

     def test_concat_episode(self):
-        # Generate a simple multi-agent episode.
-        base_episode = self._create_simple_episode(
-            [
-                {"a0": 0, "a1": 0},
-                {"a0": 1, "a1": 1},  # <- split here, then concat
-                {"a0": 2, "a1": 2},
-            ]
-        )
-        check(len(base_episode), 2)
-        # Split it into two slices.
-        episode_1, episode_2 = base_episode[:1], base_episode[1:]
-        check(len(episode_1), 1)
-        check(len(episode_2), 1)
-        # Re-concat these slices.
-        episode_1.concat_episode(episode_2)
-        check(len(episode_1), 2)
-        check(episode_1.env_t_started, 0)
-        check(episode_1.env_t, 2)
-        a0 = episode_1.agent_episodes["a0"]
-        a1 = episode_1.agent_episodes["a1"]
-        check((len(a0), len(a1)), (2, 2))
-        check((a0.t_started, a1.t_started), (0, 0))
-        check((a0.t, a1.t), (2, 2))
-        check((a0.observations, a1.observations), ([0, 1, 2], [0, 1, 2]))
-        check((a0.actions, a1.actions), ([0, 1], [0, 1]))
-        check((a0.rewards, a1.rewards), ([0.0, 0.1], [0.0, 0.1]))
-        check((a0.is_done, a1.is_done), (False, False))
+        """Tests `concat_episode` by sampling episode chunks and re-joining them.

-        # Generate a more complex multi-agent episode.
-        base_episode = self._create_simple_episode(
-            [
-                {"a0": 0, "a1": 0},
-                {"a0": 1, "a1": 1},
-                {"a1": 2},
-                {"a1": 3},
-                {"a1": 4},  # <- split here, then concat
-                {"a0": 5, "a1": 5},
-                {"a0": 6},  # <- split here, then concat
-                {"a0": 7, "a1": 7},  # <- split here, then concat
-                {"a0": 8},  # <- split here, then concat
-                {"a1": 9},
-            ]
-        )
-        check(len(base_episode), 9)
-
-        # Split it into two slices.
-        for split_ in [(4, (4, 5)), (6, (6, 3)), (7, (7, 2)), (8, (8, 1))]:
-            episode_1, episode_2 = base_episode[: split_[0]], base_episode[split_[0] :]
-            check(len(episode_1), split_[1][0])
-            check(len(episode_2), split_[1][1])
-            # Re-concat these slices.
-            episode_1.concat_episode(episode_2)
-            check(len(episode_1), 9)
-            check(episode_1.env_t_started, 0)
-            check(episode_1.env_t, 9)
-            a0 = episode_1.agent_episodes["a0"]
-            a1 = episode_1.agent_episodes["a1"]
-            check((len(a0), len(a1)), (5, 7))
-            check((a0.t_started, a1.t_started), (0, 0))
-            check((a0.t, a1.t), (5, 7))
-            check(
-                (a0.observations, a1.observations),
-                ([0, 1, 5, 6, 7, 8], [0, 1, 2, 3, 4, 5, 7, 9]),
-            )
-            check((a0.actions, a1.actions), ([0, 1, 5, 6, 7], [0, 1, 2, 3, 4, 5, 7]))
-            check(
-                (a0.rewards, a1.rewards),
-                ([0, 0.1, 0.5, 0.6, 0.7], [0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.7]),
-            )
-            check((a0.is_done, a1.is_done), (False, False))
-
-        # Test hanging rewards.
-        observations = [
-            {"a0": 0, "a1": 0},  # 0
-            {"a0": 1},  # 1
-            {"a0": 2},  # 2 <- split here, then concat
-            {"a0": 3},  # 3
-            {"a0": 4},  # 4
-        ]
-        actions = observations[:-1]
-        # a1 continues receiving rewards (along with a0's actions).
-        rewards = [
-            {"a0": 0.0, "a1": 0.0},  # 0
-            {"a0": 0.1, "a1": 1.0},  # 1
-            {"a0": 0.2, "a1": 2.0},  # 2
-            {"a0": 0.3, "a1": 3.0},  # 3
-        ]
-        base_episode = MultiAgentEpisode(
-            observations=observations,
-            actions=actions,
-            rewards=rewards,
-            len_lookback_buffer=0,
-        )
-        check(len(base_episode), 4)
-        check(base_episode._hanging_rewards_end, {"a1": 6.0})
-        episode_1, episode_2 = base_episode[:2], base_episode[2:]
-        check(len(episode_1), 2)
-        check(len(episode_2), 2)
-        # Re-concat these slices.
-        episode_1.concat_episode(episode_2)
-        check(len(episode_1), 4)
-        check(episode_1.env_t_started, 0)
-        check(episode_1.env_t, 4)
-        a0 = episode_1.agent_episodes["a0"]
-        a1 = episode_1.agent_episodes["a1"]
-        check((len(a0), len(a1)), (4, 0))
-        check((a0.t_started, a1.t_started), (0, 0))
-        check((a0.t, a1.t), (4, 0))
-        check(
-            (a0.observations, a1.observations),
-            ([0, 1, 2, 3, 4], [0]),
-        )
-        check((a0.actions, a1.actions), ([0, 1, 2, 3], []))
-        check(
-            (a0.rewards, a1.rewards),
-            ([0, 0.1, 0.2, 0.3], []),
-        )
-        check(episode_1._hanging_rewards_end, {"a1": 6.0})
-        check((a0.is_done, a1.is_done), (False, False))
+        Then check that the concatenated episodes contain the expected data.
+        """
+        env_runner = MultiAgentEnvRunner(config)
+
+        episodes = []
+        for repeat in range(10):
+            episodes += env_runner.sample(num_timesteps=4, random_actions=False)
+        unique_episode_ids = {eps.id_ for eps in episodes}
+
+        # Concat the episodes with the same episode id
+        for ep_id in unique_episode_ids:
+            eps_chunks = [ep for ep in episodes if ep.id_ == ep_id]
+
+            # Concat all chunks into the first one
+            combined = eps_chunks[0]
+            for chunk in eps_chunks[1:]:
+                combined.concat_episode(chunk)
+            # print(f' Combined episode: env_t_started={combined.env_t_started}, '
+            #       f'env_t={combined.env_t}, is_done={combined.is_done}')
+
+            # Check the episode contents for each agent
+            for agent_id, sa_episode in combined.agent_episodes.items():
+                obs = sa_episode.get_observations()
+                actions = sa_episode.get_actions()
+                rewards = sa_episode.get_rewards()
+
+                # print(f' Agent {agent_id}: len={len(sa_episode)}, '
+                #       f'obs={list(obs)}, rewards={list(rewards)}')
+
+                # Observations should be sequential: 0, 1, 2, 3, ...
+                expected_obs = list(range(len(obs)))
+                assert (
+                    list(obs) == expected_obs
+                ), f"Agent {agent_id}: expected obs {expected_obs}, got {list(obs)}"
+
+                # Actions should equal observations (EchoRLModule)
+                assert list(actions) == list(
+                    obs[:-1]
+                ), f"Agent {agent_id}: expected actions {list(obs[:-1])}, got {list(actions)}"
+
+                assert list(rewards) == [1] * len(actions)

     def test_get_return(self):
         # Generate an empty episode and ensure that the return is zero.
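
A minimal usage sketch (not part of the patch above) of the cut-then-concat contract that
the updated `concat_episode` docstrings describe, shown on a `SingleAgentEpisode` for
brevity; the concrete observation/action/reward values are illustrative only:

    from ray.rllib.env.single_agent_episode import SingleAgentEpisode

    # An episode chunk with two logged env steps (observations 0 -> 1 -> 2).
    episode = SingleAgentEpisode(
        observations=[0, 1, 2],
        actions=[0, 1],
        rewards=[1.0, 1.0],
        len_lookback_buffer=0,
    )

    # `cut()` creates a successor chunk that starts at `episode`'s last observation
    # and keeps the same ID, so the two chunks "fit together".
    successor = episode.cut()
    successor.add_env_step(observation=3, action=2, reward=1.0)

    # Re-joining works because `episode.t == successor.t_started` and the
    # observations at the concatenation boundary match.
    episode.concat_episode(successor)
    assert len(episode) == 3 and episode.t == 3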