107 changes: 100 additions & 7 deletions rllib/connectors/common/flatten_observations.py
@@ -216,31 +216,115 @@ class FlattenObservations(ConnectorV2):
output_batch["obs"][(episode_2.id_,)][0][0],
np.array([1.0, 0.0, 1.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0]),
)
# Multi-agent example: Use the connector with a multi-agent observation space.
# The observation space must be a Dict with agent IDs as top-level keys.
from ray.rllib.env.multi_agent_episode import MultiAgentEpisode
# Define a per-agent observation space.
per_agent_obs_space = gym.spaces.Dict({
"a": gym.spaces.Box(-10.0, 10.0, (), np.float32),
"b": gym.spaces.Tuple([
gym.spaces.Discrete(2),
gym.spaces.Box(-1.0, 1.0, (2, 1), np.float32),
]),
"c": gym.spaces.MultiDiscrete([2, 3]),
})
# Create a multi-agent observation space with agent IDs as keys.
multi_agent_obs_space = gym.spaces.Dict({
"agent_1": per_agent_obs_space,
"agent_2": per_agent_obs_space,
})
# Create a multi-agent episode with observations for both agents.
# Agent IDs are inferred from the keys in the observations dict.
ma_episode = MultiAgentEpisode(
observations=[
{
"agent_1": {
"a": np.array(-10.0, np.float32),
"b": (1, np.array([[-1.0], [-1.0]], np.float32)),
"c": np.array([0, 2]),
},
"agent_2": {
"a": np.array(10.0, np.float32),
"b": (0, np.array([[1.0], [1.0]], np.float32)),
"c": np.array([1, 1]),
},
},
],
)
# Construct the connector for multi-agent, flattening only agent_1's observations.
# Note: If agent_ids is None (the default), all agents' observations are flattened.
connector = FlattenObservations(
multi_agent_obs_space,
act_space,
multi_agent=True,
agent_ids=["agent_1"],
)
# Call the connector.
output_batch = connector(
rl_module=None,
batch={},
episodes=[ma_episode],
explore=True,
shared_data={},
)
# agent_1's observation is flattened.
check(
ma_episode.agent_episodes["agent_1"].get_observations(0),
# Order: Box(()), one-hot Discrete(2), Box((2, 1)), one-hot MultiDiscrete([2, 3]).
np.array([-10.0, 0.0, 1.0, -1.0, -1.0, 1.0, 0.0, 0.0, 0.0, 1.0]),
)
# agent_2's observation is unchanged (not in agent_ids).
check(
ma_episode.agent_episodes["agent_2"].get_observations(0),
{
"a": np.array(10.0, np.float32),
"b": (0, np.array([[1.0], [1.0]], np.float32)),
"c": np.array([1, 1]),
},
)
"""

@override(ConnectorV2)
def recompute_output_observation_space(
self,
input_observation_space: gym.Space,
input_action_space: gym.Space,
) -> gym.Space:
self._input_obs_base_struct = get_base_struct_from_space(
self.input_observation_space
)

if self._multi_agent:
spaces = {}
assert isinstance(
input_observation_space, gym.spaces.Dict
), f"To flatten a Multi-Agent observation, it is expected that observation space is a dictionary, its actual type is {type(input_observation_space)}"

for agent_id, space in input_observation_space.items():
# Remove keys, if necessary.
# TODO (simon): Maybe allow removing different keys for different agents.
if self._keys_to_remove:
assert isinstance(
space, gym.spaces.Dict
), f"To remove keys from an observation space requires that it be a dictionary, its actual type is {type(space)}"

self._input_obs_base_struct[agent_id] = {
k: v
for k, v in self._input_obs_base_struct[agent_id].items()
if k not in self._keys_to_remove
}

if self._agent_ids and agent_id not in self._agent_ids:
# For nested spaces, we need to use the original Spaces (rather than the reduced version)
spaces[agent_id] = self.input_observation_space[agent_id]
else:
sample = flatten_inputs_to_1d_tensor(
tree.map_structure(
@@ -253,15 +337,20 @@ def recompute_output_observation_space(
spaces[agent_id] = Box(
float("-inf"), float("inf"), (len(sample),), np.float32
)
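# (Editor's illustration, not part of the original file:) for the per-agent
# space in the docstring example above, the flattened length works out to
# 1 (Box(())) + 2 (one-hot Discrete(2)) + 2 (Box((2, 1))) + 5 (one-hot
# MultiDiscrete([2, 3])) = 10, so each flattened agent's space becomes
# Box(-inf, inf, (10,), float32).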

return gym.spaces.Dict(spaces)
else:
# Remove keys, if necessary.
if self._keys_to_remove:
assert isinstance(
input_observation_space, gym.spaces.Dict
), f"To remove keys from an observation space requires that it be a dictionary, its actual type is {type(input_observation_space)}"
self._input_obs_base_struct = {
k: v
for k, v in self._input_obs_base_struct.items()
if k not in self._keys_to_remove
}

sample = flatten_inputs_to_1d_tensor(
tree.map_structure(
lambda s: s.sample(),
@@ -286,13 +375,17 @@ def __init__(
"""Initializes a FlattenObservations instance.
Args:
input_observation_space: The input observation space. For multi-agent
setups, this must be a Dict space with agent IDs as top-level keys
mapping to each agent's individual observation space.
input_action_space: The input action space.
multi_agent: Whether this connector operates on multi-agent observations,
in which case, the top-level of the Dict space (where agent IDs are
mapped to individual agents' observation spaces) is left as-is.
agent_ids: If multi_agent is True, this argument defines a collection of
AgentIDs for which to flatten. AgentIDs not in this collection will
have their observations passed through unchanged.
If None (the default), flatten observations for all AgentIDs.
as_learner_connector: Whether this connector is part of a Learner connector
pipeline, as opposed to an env-to-module pipeline.
Note, this is usually only used for offline rl where the data comes
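Below is a minimal usage sketch, not part of this PR, of how the per-agent flattening exercised by the new docstring example could be wired into an env-to-module pipeline via AlgorithmConfig. The environment name "my_multi_agent_env" is a hypothetical registered env, the new API stack is assumed to be enabled, and the exact callable signature expected by env_to_module_connector can vary between RLlib versions.

from ray.rllib.algorithms.ppo import PPOConfig
from ray.rllib.connectors.env_to_module import FlattenObservations

config = (
    PPOConfig()
    # Hypothetical multi-agent env registered via ray.tune.register_env().
    .environment("my_multi_agent_env")
    .env_runners(
        # Flatten only agent_1's observations; agent_2's observations pass
        # through unchanged.
        env_to_module_connector=lambda env: FlattenObservations(
            multi_agent=True,
            agent_ids=["agent_1"],
        ),
    )
)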