diff --git a/rllib/connectors/common/flatten_observations.py b/rllib/connectors/common/flatten_observations.py index 51dc8e8c038f..93049192d360 100644 --- a/rllib/connectors/common/flatten_observations.py +++ b/rllib/connectors/common/flatten_observations.py @@ -216,13 +216,87 @@ class FlattenObservations(ConnectorV2): output_batch["obs"][(episode_2.id_,)][0][0], np.array([1.0, 0.0, 1.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0]), ) + + # Multi-agent example: Use the connector with a multi-agent observation space. + # The observation space must be a Dict with agent IDs as top-level keys. + from ray.rllib.env.multi_agent_episode import MultiAgentEpisode + + # Define a per-agent observation space. + per_agent_obs_space = gym.spaces.Dict({ + "a": gym.spaces.Box(-10.0, 10.0, (), np.float32), + "b": gym.spaces.Tuple([ + gym.spaces.Discrete(2), + gym.spaces.Box(-1.0, 1.0, (2, 1), np.float32), + ]), + "c": gym.spaces.MultiDiscrete([2, 3]), + }) + + # Create a multi-agent observation space with agent IDs as keys. + multi_agent_obs_space = gym.spaces.Dict({ + "agent_1": per_agent_obs_space, + "agent_2": per_agent_obs_space, + }) + + # Create a multi-agent episode with observations for both agents. + # Agent IDs are inferred from the keys in the observations dict. + ma_episode = MultiAgentEpisode( + observations=[ + { + "agent_1": { + "a": np.array(-10.0, np.float32), + "b": (1, np.array([[-1.0], [-1.0]], np.float32)), + "c": np.array([0, 2]), + }, + "agent_2": { + "a": np.array(10.0, np.float32), + "b": (0, np.array([[1.0], [1.0]], np.float32)), + "c": np.array([1, 1]), + }, + }, + ], + ) + + # Construct the connector for multi-agent, flattening only agent_1's observations. + # Note: If agent_ids is None (the default), all agents' observations are flattened. + connector = FlattenObservations( + multi_agent_obs_space, + act_space, + multi_agent=True, + agent_ids=["agent_1"], + ) + + # Call the connector. + output_batch = connector( + rl_module=None, + batch={}, + episodes=[ma_episode], + explore=True, + shared_data={}, + ) + + # agent_1's observation is flattened. + check( + ma_episode.agent_episodes["agent_1"].get_observations(0), + # box() disc(2). box(2, 1). multidisc(2, 3)........ + np.array([-10.0, 0.0, 1.0, -1.0, -1.0, 1.0, 0.0, 0.0, 0.0, 1.0]), + ) + + # agent_2's observation is unchanged (not in agent_ids). + check( + ma_episode.agent_episodes["agent_2"].get_observations(0), + { + "a": np.array(10.0, np.float32), + "b": (0, np.array([[1.0], [1.0]], np.float32)), + "c": np.array([1, 1]), + }, + ) """ @override(ConnectorV2) def recompute_output_observation_space( self, - input_observation_space, - input_action_space, + input_observation_space: gym.Space, + input_action_space: gym.Space, ) -> gym.Space: self._input_obs_base_struct = get_base_struct_from_space( self.input_observation_space @@ -230,17 +304,27 @@ def recompute_output_observation_space( if self._multi_agent: spaces = {} - for agent_id, space in self._input_obs_base_struct.items(): + assert isinstance( + input_observation_space, gym.spaces.Dict + ), f"To flatten a Multi-Agent observation, it is expected that observation space is a dictionary, its actual type is {type(input_observation_space)}" + + for agent_id, space in input_observation_space.items(): # Remove keys, if necessary. # TODO (simon): Maybe allow to remove different keys for different agents. if self._keys_to_remove: + assert isinstance( + space, gym.spaces.Dict + ), f"To remove keys from an observation space requires that it be a dictionary, its actual type is {type(space)}" + self._input_obs_base_struct[agent_id] = { k: v for k, v in self._input_obs_base_struct[agent_id].items() if k not in self._keys_to_remove } + if self._agent_ids and agent_id not in self._agent_ids: - spaces[agent_id] = self._input_obs_base_struct[agent_id] + # For nested spaces, we need to use the original Spaces (rather than the reduced version) + spaces[agent_id] = self.input_observation_space[agent_id] else: sample = flatten_inputs_to_1d_tensor( tree.map_structure( @@ -253,15 +337,20 @@ def recompute_output_observation_space( spaces[agent_id] = Box( float("-inf"), float("inf"), (len(sample),), np.float32 ) + return gym.spaces.Dict(spaces) else: # Remove keys, if necessary. if self._keys_to_remove: + assert isinstance( + input_observation_space, gym.spaces.Dict + ), f"To remove keys from an observation space requires that it be a dictionary, its actual type is {type(input_observation_space)}" self._input_obs_base_struct = { k: v for k, v in self._input_obs_base_struct.items() if k not in self._keys_to_remove } + sample = flatten_inputs_to_1d_tensor( tree.map_structure( lambda s: s.sample(), @@ -286,13 +375,17 @@ def __init__( """Initializes a FlattenObservations instance. Args: + input_observation_space: The input observation space. For multi-agent + setups, this must be a Dict space with agent IDs as top-level keys + mapping to each agent's individual observation space. + input_action_space: The input action space. multi_agent: Whether this connector operates on multi-agent observations, in which case, the top-level of the Dict space (where agent IDs are mapped to individual agents' observation spaces) is left as-is. agent_ids: If multi_agent is True, this argument defines a collection of - AgentIDs for which to flatten. AgentIDs not in this collection are - ignored. - If None, flatten observations for all AgentIDs. None is the default. + AgentIDs for which to flatten. AgentIDs not in this collection will + have their observations passed through unchanged. + If None (the default), flatten observations for all AgentIDs. as_learner_connector: Whether this connector is part of a Learner connector pipeline, as opposed to an env-to-module pipeline. Note, this is usually only used for offline rl where the data comes