3 changes: 2 additions & 1 deletion rllib/connectors/common/numpy_to_tensor.py
@@ -111,7 +111,8 @@ def __call__(
)
else:
raise ValueError(
"`NumpyToTensor`does NOT support frameworks other than torch!"
"`NumpyToTensor`does NOT support frameworks other than torch! "
f"Your current framework is {rl_module.framework}"
)
if infos is not None:
module_data[Columns.INFOS] = infos
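For context, a minimal sketch (not part of the diff) of how the enriched message composes; the framework value here is a hypothetical stand-in for `rl_module.framework`:

framework = "tf2"  # hypothetical: whatever the non-torch RLModule reports
msg = (
    "`NumpyToTensor` does NOT support frameworks other than torch! "
    f"Your current framework is {framework}"
)
print(msg)  # ... other than torch! Your current framework is tf2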
23 changes: 4 additions & 19 deletions rllib/env/multi_agent_episode.py
@@ -798,8 +798,10 @@ def concat_episode(self, other: "MultiAgentEpisode") -> None:
"""Adds the given `other` MultiAgentEpisode to the right side of `self`.

In order for this to work, both chunks (`self` and `other`) must fit
together. This is checked by the IDs (must be identical), the time step counters
(`self.env_t` must be the same as `episode_chunk.env_t_started`), as well as the
together, i.e. they must have been split through `cut`. For sequential
multi-agent environments, using `slice` might cause problems due to hanging
observations/actions.
Contributor:
This is something we need to fix in the near future. Could you raise another issue on Ray OSS, please?

Member Author:
I'm not sure if this is a bug or an inherent limitation of the `slice` method.

This is checked by the IDs (must be identical), the time step counters
(`self.env_t` must be the same as `other.env_t_started`), as well as the
observations/infos of the individual agents at the concatenation boundaries.
Also, `self.is_done` must not be True, meaning `self.is_terminated` and
`self.is_truncated` are both False.
@@ -842,23 +844,6 @@ def concat_episode(self, other: "MultiAgentEpisode") -> None:
# If the agent has data in both chunks, concatenate on the single-agent
# level, thereby making sure the hanging values (begin and end) match.
elif agent_id in other.agent_episodes:
# If `other` has hanging (end) values -> Add these to `self`'s agent
# SingleAgentEpisode (as a new timestep) and only then concatenate.
# Otherwise, the concatenation would fail b/c of missing data.
if agent_id in self._hanging_actions_end:
assert agent_id in self._hanging_extra_model_outputs_end
sa_episode.add_env_step(
observation=other.agent_episodes[agent_id].get_observations(0),
infos=other.agent_episodes[agent_id].get_infos(0),
action=self._hanging_actions_end[agent_id],
reward=(
self._hanging_rewards_end[agent_id]
+ other._hanging_rewards_begin[agent_id]
),
extra_model_outputs=(
self._hanging_extra_model_outputs_end[agent_id]
),
)
sa_episode.concat_episode(other.agent_episodes[agent_id])
# Override `self`'s hanging (end) values with `other`'s hanging (end).
if agent_id in other._hanging_actions_end:
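To make these preconditions concrete, here is a minimal sketch (mirroring the simple-episode test further below, not part of the diff) of a split-then-concat round trip:

from ray.rllib.env.multi_agent_episode import MultiAgentEpisode

# Build a tiny two-agent episode, split it, then re-concatenate.
episode = MultiAgentEpisode(
    observations=[{"a0": 0, "a1": 0}, {"a0": 1, "a1": 1}, {"a0": 2, "a1": 2}],
    actions=[{"a0": 0, "a1": 0}, {"a0": 1, "a1": 1}],
    rewards=[{"a0": 0.0, "a1": 0.0}, {"a0": 0.1, "a1": 0.1}],
    len_lookback_buffer=0,
)
first, second = episode[:1], episode[1:]
# The chunks line up exactly at the boundary ...
assert first.env_t == second.env_t_started
# ... so concatenation restores the original two-step episode.
first.concat_episode(second)
assert len(first) == 2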
4 changes: 2 additions & 2 deletions rllib/env/single_agent_episode.py
@@ -598,7 +598,7 @@ def concat_episode(self, other: "SingleAgentEpisode") -> None:

In order for this to work, both chunks (`self` and `other`) must fit
together. This is checked by the IDs (must be identical), the time step counters
(`self.env_t` must be the same as `episode_chunk.env_t_started`), as well as the
(`self.env_t` must be the same as `other.env_t_started`), as well as the
observations/infos at the concatenation boundaries. Also, `self.is_done` must
not be True, meaning `self.is_terminated` and `self.is_truncated` are both
False.
@@ -615,7 +615,7 @@ def concat_episode(self, other: "SingleAgentEpisode") -> None:
# able to concatenate.
assert not self.is_done
# Make sure the timesteps match.
assert self.t == other.t_started
assert self.t == other.t_started, f"{self.t=}, {other.t_started=}"
# Validate `other`.
other.validate()

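Analogously, a minimal sketch (not from the diff, and assuming `SingleAgentEpisode` supports the same slicing as its multi-agent counterpart) of the invariant whose failure the new f-string message reports:

from ray.rllib.env.single_agent_episode import SingleAgentEpisode

episode = SingleAgentEpisode(
    observations=[0, 1, 2],
    actions=[0, 1],
    rewards=[0.0, 0.1],
    len_lookback_buffer=0,
)
first, second = episode[:1], episode[1:]
# The precondition the enriched assert message now makes debuggable:
assert first.t == second.t_started, f"{first.t=}, {second.t_started=}"
first.concat_episode(second)
assert first.t == 2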
290 changes: 171 additions & 119 deletions rllib/env/tests/test_multi_agent_episode.py
@@ -1,14 +1,20 @@
import unittest
from typing import Optional, Tuple
from typing import Any, Callable, Dict, Optional, Tuple

import gymnasium as gym
import numpy as np

import ray
from ray.rllib.algorithms.ppo.ppo import PPOConfig
from ray.rllib.core.columns import Columns
from ray.rllib.core.rl_module.rl_module import RLModule, RLModuleSpec
from ray.rllib.env.multi_agent_env import MultiAgentEnv
from ray.rllib.env.multi_agent_env_runner import MultiAgentEnvRunner
from ray.rllib.env.multi_agent_episode import MultiAgentEpisode
from ray.rllib.utils.annotations import override
from ray.rllib.utils.test_utils import check
from ray.rllib.utils.typing import MultiAgentDict
from ray.tune import register_env


class MultiAgentTestEnv(MultiAgentEnv):
@@ -125,6 +131,128 @@ def step(
return obs, reward, is_terminated, is_truncated, info


class EchoRLModule(RLModule):
"""An RLModule that returns the observation as the action (for testing)."""

framework = "torch"

@override(RLModule)
def _forward(self, batch: Dict[str, Any], **kwargs) -> Dict[str, Any]:
"""Return the observation as the action."""
obs = batch[Columns.OBS]
# For Discrete observation space, obs is already an integer/array of integers
return {Columns.ACTIONS: obs}

@override(RLModule)
def _forward_inference(self, batch: Dict[str, Any], **kwargs) -> Dict[str, Any]:
return self._forward(batch, **kwargs)

@override(RLModule)
def _forward_exploration(self, batch: Dict[str, Any], **kwargs) -> Dict[str, Any]:
return self._forward(batch, **kwargs)

@override(RLModule)
def _forward_train(self, batch: Dict[str, Any], **kwargs) -> Dict[str, Any]:
raise NotImplementedError("EchoRLModule is not trainable!")
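A quick smoke check (an illustration, not part of the diff; `_forward` is called unbound here since it does not use `self`) of the echo contract the test below relies on:

batch = {Columns.OBS: np.array([3, 7, 11])}
out = EchoRLModule._forward(None, batch)
# Actions always equal the observations that produced them.
assert (out[Columns.ACTIONS] == batch[Columns.OBS]).all()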


class SequentialMultiAgentEnv(MultiAgentEnv):
def __init__(
self, agent_fns: dict[str, Callable[[int], bool]], max_episode_length: int = 100
):
super().__init__()

self.agents = list(agent_fns.keys())
self.possible_agents = list(agent_fns.keys())
self.agent_fns = agent_fns

self.observation_space = gym.spaces.Dict(
{
agent: gym.spaces.Discrete(max_episode_length)
for agent in self.possible_agents
}
)
self.action_space = gym.spaces.Dict(
{
agent: gym.spaces.Discrete(max_episode_length)
for agent in self.possible_agents
}
)

self.agent_timestep = {}
self.env_timestep = 0
self.max_episode_length = max_episode_length

def reset(
self,
*,
seed: Optional[int] = None,
options: Optional[dict] = None,
) -> Tuple[MultiAgentDict, MultiAgentDict]:
self.env_timestep = 0
self.agent_timestep = {agent: 0 for agent in self.possible_agents}

return self.get_obs(), {"env_timestep": self.env_timestep}

def step(
self, action_dict: MultiAgentDict
) -> Tuple[
MultiAgentDict, MultiAgentDict, MultiAgentDict, MultiAgentDict, MultiAgentDict
]:
if self.env_timestep >= self.max_episode_length:
obs = self.agent_timestep
else:
obs = self.get_obs()

# Replace 1 with self.env_timestep when debugging, to see which timestep each observation comes from.
rewards = {agent: 1 for agent in obs.keys()}
terminated = {"__all__": self.env_timestep == self.max_episode_length}
truncated = {}
info = {agent: {"env_timestep": self.env_timestep} for agent in obs.keys()}

self.env_timestep += 1

return obs, rewards, terminated, truncated, info

def get_obs(self) -> dict[str, int]:
obs = {}
for agent, fn in self.agent_fns.items():
if fn(self.env_timestep):
obs[agent] = self.agent_timestep[agent]
self.agent_timestep[agent] += 1

assert obs, f"{obs=}, {self.env_timestep=}, {self.agent_timestep=}"
return obs


def create_env(config):
return SequentialMultiAgentEnv(
agent_fns={
"p_true": lambda x: True,
"p_mod_2": lambda x: x % 2 == 0,
"p_mod_5": lambda x: x % 5 == 0,
"p_within": lambda x: x in [0, 10, 15],
},
max_episode_length=20,
)


register_env("env", create_env)
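For intuition (an illustration, not part of the test file): the `agent_fns` above make a different subset of agents observe at each env timestep, which is what produces hanging actions/rewards when an episode chunk is cut between such steps:

# Which agents observe at env timesteps 0-5 under the fns above:
fns = {
    "p_true": lambda x: True,
    "p_mod_2": lambda x: x % 2 == 0,
    "p_mod_5": lambda x: x % 5 == 0,
    "p_within": lambda x: x in [0, 10, 15],
}
for t in range(6):
    print(t, sorted(a for a, fn in fns.items() if fn(t)))
# 0 ['p_mod_2', 'p_mod_5', 'p_true', 'p_within']
# 1 ['p_true']
# 2 ['p_mod_2', 'p_true']
# 3 ['p_true']
# 4 ['p_mod_2', 'p_true']
# 5 ['p_mod_5', 'p_true']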


config = (
PPOConfig()
.environment("env")
.env_runners(num_envs_per_env_runner=2, num_env_runners=0)
.rl_module(rl_module_spec=RLModuleSpec(module_class=EchoRLModule))
.multi_agent(
policies={"p0"},
policy_mapping_fn=lambda aid, eps, **kw: "p0",
policies_to_train=[],
)
)


# TODO (simon): Test `get_state()` and `from_state()`.
class TestMultiAgentEpisode(unittest.TestCase):
@classmethod
@@ -3135,125 +3263,49 @@ def test_slice(self):
check((a0.is_done, a1.is_done), (False, False))

def test_concat_episode(self):
# Generate a simple multi-agent episode.
base_episode = self._create_simple_episode(
[
{"a0": 0, "a1": 0},
{"a0": 1, "a1": 1}, # <- split here, then concat
{"a0": 2, "a1": 2},
]
)
check(len(base_episode), 2)
# Split it into two slices.
episode_1, episode_2 = base_episode[:1], base_episode[1:]
check(len(episode_1), 1)
check(len(episode_2), 1)
# Re-concat these slices.
episode_1.concat_episode(episode_2)
check(len(episode_1), 2)
check(episode_1.env_t_started, 0)
check(episode_1.env_t, 2)
a0 = episode_1.agent_episodes["a0"]
a1 = episode_1.agent_episodes["a1"]
check((len(a0), len(a1)), (2, 2))
check((a0.t_started, a1.t_started), (0, 0))
check((a0.t, a1.t), (2, 2))
check((a0.observations, a1.observations), ([0, 1, 2], [0, 1, 2]))
check((a0.actions, a1.actions), ([0, 1], [0, 1]))
check((a0.rewards, a1.rewards), ([0.0, 0.1], [0.0, 0.1]))
check((a0.is_done, a1.is_done), (False, False))
"""Tests the `concat_episode` function through sampling multiple episodes then rejoining them.

# Generate a more complex multi-agent episode.
base_episode = self._create_simple_episode(
[
{"a0": 0, "a1": 0},
{"a0": 1, "a1": 1},
{"a1": 2},
{"a1": 3},
{"a1": 4}, # <- split here, then concat
{"a0": 5, "a1": 5},
{"a0": 6}, # <- split here, then concat
{"a0": 7, "a1": 7}, # <- split here, then concat
{"a0": 8}, # <- split here, then concat
{"a1": 9},
]
)
check(len(base_episode), 9)

# Split it into two slices.
for split_ in [(4, (4, 5)), (6, (6, 3)), (7, (7, 2)), (8, (8, 1))]:
episode_1, episode_2 = base_episode[: split_[0]], base_episode[split_[0] :]
check(len(episode_1), split_[1][0])
check(len(episode_2), split_[1][1])
# Re-concat these slices.
episode_1.concat_episode(episode_2)
check(len(episode_1), 9)
check(episode_1.env_t_started, 0)
check(episode_1.env_t, 9)
a0 = episode_1.agent_episodes["a0"]
a1 = episode_1.agent_episodes["a1"]
check((len(a0), len(a1)), (5, 7))
check((a0.t_started, a1.t_started), (0, 0))
check((a0.t, a1.t), (5, 7))
check(
(a0.observations, a1.observations),
([0, 1, 5, 6, 7, 8], [0, 1, 2, 3, 4, 5, 7, 9]),
)
check((a0.actions, a1.actions), ([0, 1, 5, 6, 7], [0, 1, 2, 3, 4, 5, 7]))
check(
(a0.rewards, a1.rewards),
([0, 0.1, 0.5, 0.6, 0.7], [0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.7]),
)
check((a0.is_done, a1.is_done), (False, False))

# Test hanging rewards.
observations = [
{"a0": 0, "a1": 0}, # 0
{"a0": 1}, # 1
{"a0": 2}, # 2 <- split here, then concat
{"a0": 3}, # 3
{"a0": 4}, # 4
]
actions = observations[:-1]
# a1 continues receiving rewards (along with a0's actions).
rewards = [
{"a0": 0.0, "a1": 0.0}, # 0
{"a0": 0.1, "a1": 1.0}, # 1
{"a0": 0.2, "a1": 2.0}, # 2
{"a0": 0.3, "a1": 3.0}, # 3
]
base_episode = MultiAgentEpisode(
observations=observations,
actions=actions,
rewards=rewards,
len_lookback_buffer=0,
)
check(len(base_episode), 4)
check(base_episode._hanging_rewards_end, {"a1": 6.0})
episode_1, episode_2 = base_episode[:2], base_episode[2:]
check(len(episode_1), 2)
check(len(episode_2), 2)
# Re-concat these slices.
episode_1.concat_episode(episode_2)
check(len(episode_1), 4)
check(episode_1.env_t_started, 0)
check(episode_1.env_t, 4)
a0 = episode_1.agent_episodes["a0"]
a1 = episode_1.agent_episodes["a1"]
check((len(a0), len(a1)), (4, 0))
check((a0.t_started, a1.t_started), (0, 0))
check((a0.t, a1.t), (4, 0))
check(
(a0.observations, a1.observations),
([0, 1, 2, 3, 4], [0]),
)
check((a0.actions, a1.actions), ([0, 1, 2, 3], []))
check(
(a0.rewards, a1.rewards),
([0, 0.1, 0.2, 0.3], []),
)
check(episode_1._hanging_rewards_end, {"a1": 6.0})
check((a0.is_done, a1.is_done), (False, False))
Then checks that the concatenated episodes contain the expected data.
"""
env_runner = MultiAgentEnvRunner(config)

episodes = []
for repeat in range(10):
episodes += env_runner.sample(num_timesteps=4, random_actions=False)
unique_episode_ids = {eps.id_ for eps in episodes}

# Concatenate the episodes that share the same episode ID.
for ep_id in unique_episode_ids:
eps_chunks = [ep for ep in episodes if ep.id_ == ep_id]

# Concat all chunks into the first one
combined = eps_chunks[0]
for chunk in eps_chunks[1:]:
combined.concat_episode(chunk)
# print(f' Combined episode: env_t_started={combined.env_t_started}, '
# f'env_t={combined.env_t}, is_done={combined.is_done}')

# Check the episode contents for each agent
for agent_id, sa_episode in combined.agent_episodes.items():
obs = sa_episode.get_observations()
actions = sa_episode.get_actions()
rewards = sa_episode.get_rewards()

# print(f' Agent {agent_id}: len={len(sa_episode)}, '
# f'obs={list(obs)}, rewards={list(rewards)}')

# Observations should be sequential: 0, 1, 2, 3, ...
expected_obs = list(range(len(obs)))
assert (
list(obs) == expected_obs
), f"Agent {agent_id}: expected obs {expected_obs}, got {list(obs)}"

# Actions should equal observations (EchoRLModule)
assert list(actions) == list(
obs[:-1]
), f"Agent {agent_id}: expected actions {list(obs[:-1])}, got {list(actions)}"

assert list(rewards) == [1] * len(actions)

def test_get_return(self):
# Generate an empty episode and ensure that the return is zero.
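A note on running the new test in isolation (assuming the usual pytest setup for RLlib test files):

# pytest rllib/env/tests/test_multi_agent_episode.py -k test_concat_episode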