
Commit ca29fec

[RLlib] New ConnectorV2 API #1: Some preparatory cleanups and fixes. (ray-project#41074)
1 parent 0e2a523 commit ca29fec

15 files changed: +194 additions, -88 deletions


rllib/algorithms/algorithm.py

Lines changed: 4 additions & 1 deletion
@@ -1932,7 +1932,10 @@ def compute_actions(
         filtered_obs, filtered_state = [], []
         for agent_id, ob in observations.items():
             worker = self.workers.local_worker()
-            preprocessed = worker.preprocessors[policy_id].transform(ob)
+            if worker.preprocessors.get(policy_id) is not None:
+                preprocessed = worker.preprocessors[policy_id].transform(ob)
+            else:
+                preprocessed = ob
             filtered = worker.filters[policy_id](preprocessed, update=False)
             filtered_obs.append(filtered)
             if state is None:
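With this change, `compute_actions` falls back to the raw observation whenever no preprocessor is registered for the given policy. A hedged sketch of exercising that code path follows; the `CartPole-v1` env, PPO, and the `_disable_preprocessor_api` toggle are illustrative assumptions, not part of this commit:

```python
# Sketch only: calling Algorithm.compute_actions() when a per-policy
# preprocessor may be missing (assumed here: preprocessor API disabled).
import gymnasium as gym

from ray.rllib.algorithms.ppo import PPOConfig

config = (
    PPOConfig()
    .environment("CartPole-v1")
    # Assumption: disabling the preprocessor API is one way to end up with no
    # entry under `worker.preprocessors[policy_id]`.
    .experimental(_disable_preprocessor_api=True)
)
algo = config.build()

obs_space = gym.make("CartPole-v1").observation_space
# `observations` maps agent ids to raw observations; with the guard above,
# they pass through unchanged when no preprocessor is found.
actions = algo.compute_actions({"agent_0": obs_space.sample()})
print(actions)
```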

rllib/algorithms/algorithm_config.py

Lines changed: 24 additions & 9 deletions
@@ -319,26 +319,37 @@ def __init__(self, algo_class=None):
         # If not specified, we will try to auto-detect this.
         self._is_atari = None
 
+        # TODO (sven): Rename this method into `AlgorithmConfig.sampling()`
         # `self.rollouts()`
         self.env_runner_cls = None
+        # TODO (sven): Rename into `num_env_runner_workers`.
         self.num_rollout_workers = 0
         self.num_envs_per_worker = 1
-        self.sample_collector = SimpleListCollector
         self.create_env_on_local_worker = False
-        self.sample_async = False
         self.enable_connectors = True
-        self.update_worker_filter_stats = True
-        self.use_worker_filter_stats = True
+        # TODO (sven): Rename into `sample_timesteps` (or `sample_duration`
+        # and `sample_duration_unit` (replacing batch_mode), like we do it
+        # in the evaluation config).
         self.rollout_fragment_length = 200
+        # TODO (sven): Rename into `sample_mode`.
         self.batch_mode = "truncate_episodes"
+        # TODO (sven): Rename into `validate_env_runner_workers_after_construction`.
+        self.validate_workers_after_construction = True
+        self.compress_observations = False
+        # TODO (sven): Rename into `env_runner_perf_stats_ema_coef`.
+        self.sampler_perf_stats_ema_coef = None
+
+        # TODO (sven): Deprecate together with old API stack.
+        self.sample_async = False
         self.remote_worker_envs = False
         self.remote_env_batch_wait_ms = 0
-        self.validate_workers_after_construction = True
+        self.enable_tf1_exec_eagerly = False
+        self.sample_collector = SimpleListCollector
         self.preprocessor_pref = "deepmind"
         self.observation_filter = "NoFilter"
-        self.compress_observations = False
-        self.enable_tf1_exec_eagerly = False
-        self.sampler_perf_stats_ema_coef = None
+        self.update_worker_filter_stats = True
+        self.use_worker_filter_stats = True
+        # TODO (sven): End: deprecate.
 
         # `self.training()`
         self.gamma = 0.99
@@ -890,7 +901,7 @@ def validate(self) -> None:
                 error=True,
             )
 
-        # RLModule API only works with connectors and with Learner API.
+        # New API stack (RLModule, Learner APIs) only works with connectors.
         if not self.enable_connectors and self._enable_new_api_stack:
             raise ValueError(
                 "The new API stack (RLModule and Learner APIs) only works with "
@@ -938,6 +949,8 @@ def validate(self) -> None:
                 "https://github.com/ray-project/ray/issues/35409 for more details."
             )
 
+        # TODO (sven): Remove this hack. We should not have env-var dependent logic
+        # in the codebase.
         if bool(os.environ.get("RLLIB_ENABLE_RL_MODULE", False)):
             # Enable RLModule API and connectors if env variable is set
             # (to be used in unittesting)
@@ -1765,6 +1778,8 @@ def training(
                 dashboard. If you're seeing that the object store is filling up,
                 turn down the number of remote requests in flight, or enable compression
                 in your experiment of timesteps.
+            learner_class: The `Learner` class to use for (distributed) updating of the
+                RLModule. Only used when `_enable_new_api_stack=True`.
 
         Returns:
             This updated AlgorithmConfig object.
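For orientation, here is a hedged sketch of setting the options reorganized above through `AlgorithmConfig`'s builder methods. The numbers are arbitrary but chosen so that workers x envs x `rollout_fragment_length` matches `train_batch_size`; `MyLearner` in the comment is a hypothetical placeholder:

```python
# Sketch only: the rollout/training settings touched in this commit,
# configured via the fluent AlgorithmConfig API (PPO used as an example).
from ray.rllib.algorithms.ppo import PPOConfig

config = (
    PPOConfig()
    .environment("CartPole-v1")
    .rollouts(
        num_rollout_workers=2,           # TODO above: rename to `num_env_runner_workers`
        num_envs_per_worker=1,
        rollout_fragment_length=200,     # TODO above: rename to `sample_timesteps`/`sample_duration`
        batch_mode="truncate_episodes",  # TODO above: rename to `sample_mode`
        enable_connectors=True,          # required by the new API stack (see validate() above)
    )
    .training(gamma=0.99, train_batch_size=400)
    # With `_enable_new_api_stack=True`, a custom Learner could be plugged in via
    # `.training(learner_class=MyLearner)`; `MyLearner` is hypothetical here.
)
algo = config.build()
```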

rllib/core/learner/learner.py

Lines changed: 2 additions & 0 deletions
@@ -63,6 +63,7 @@
     ResultDict,
     TensorType,
 )
+from ray.util.annotations import PublicAPI
 
 if TYPE_CHECKING:
     from ray.rllib.core.rl_module.torch.torch_compile_config import TorchCompileConfig
@@ -226,6 +227,7 @@ def get_hps_for_module(self, module_id: ModuleID) -> "LearnerHyperparameters":
         return self
 
 
+@PublicAPI(stability="alpha")
 class Learner:
     """Base class for Learners.

rllib/core/learner/learner_group.py

Lines changed: 2 additions & 0 deletions
@@ -29,6 +29,7 @@
 from ray.rllib.utils.numpy import convert_to_numpy
 from ray.train._internal.backend_executor import BackendExecutor
 from ray.tune.utils.file_transfer import sync_dir_between_nodes
+from ray.util.annotations import PublicAPI
 
 
 if TYPE_CHECKING:
@@ -58,6 +59,7 @@ def _is_module_trainable(module_id: ModuleID, batch: MultiAgentBatch) -> bool:
     return True
 
 
+@PublicAPI(stability="alpha")
 class LearnerGroup:
     """Coordinator of Learners.
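Both `Learner` (above) and `LearnerGroup` are now annotated with `PublicAPI(stability="alpha")`. A small, self-contained illustration of what that decorator is for; `MyAlphaComponent` is just a stand-in class:

```python
# Sketch only: Ray's API-stability annotation, as applied to Learner and
# LearnerGroup in this commit.
from ray.util.annotations import PublicAPI


@PublicAPI(stability="alpha")
class MyAlphaComponent:
    """Alpha components are public, but their interface may still change."""


# The decorator is metadata for docs/tooling; instances behave as usual.
component = MyAlphaComponent()
```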

rllib/core/models/torch/encoder.py

Lines changed: 17 additions & 28 deletions
@@ -285,6 +285,21 @@ def __init__(self, config: RecurrentEncoderConfig) -> None:
             bias=config.use_bias,
         )
 
+        self._state_in_out_spec = {
+            "h": TensorSpec(
+                "b, l, d",
+                d=self.config.hidden_dim,
+                l=self.config.num_layers,
+                framework="torch",
+            ),
+            "c": TensorSpec(
+                "b, l, d",
+                d=self.config.hidden_dim,
+                l=self.config.num_layers,
+                framework="torch",
+            ),
+        }
+
     @override(Model)
     def get_input_specs(self) -> Optional[Spec]:
         return SpecDict(
@@ -293,20 +308,7 @@ def get_input_specs(self) -> Optional[Spec]:
                 SampleBatch.OBS: TensorSpec(
                     "b, t, d", d=self.config.input_dims[0], framework="torch"
                 ),
-                STATE_IN: {
-                    "h": TensorSpec(
-                        "b, l, h",
-                        h=self.config.hidden_dim,
-                        l=self.config.num_layers,
-                        framework="torch",
-                    ),
-                    "c": TensorSpec(
-                        "b, l, h",
-                        h=self.config.hidden_dim,
-                        l=self.config.num_layers,
-                        framework="torch",
-                    ),
-                },
+                STATE_IN: self._state_in_out_spec,
             }
         )
 
@@ -317,20 +319,7 @@ def get_output_specs(self) -> Optional[Spec]:
                 ENCODER_OUT: TensorSpec(
                     "b, t, d", d=self.config.output_dims[0], framework="torch"
                 ),
-                STATE_OUT: {
-                    "h": TensorSpec(
-                        "b, l, h",
-                        h=self.config.hidden_dim,
-                        l=self.config.num_layers,
-                        framework="torch",
-                    ),
-                    "c": TensorSpec(
-                        "b, l, h",
-                        h=self.config.hidden_dim,
-                        l=self.config.num_layers,
-                        framework="torch",
-                    ),
-                },
+                STATE_OUT: self._state_in_out_spec,
             }
         )
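The refactor builds the recurrent h/c state spec once in `__init__` and reuses it for both `STATE_IN` and `STATE_OUT`, so the two can no longer drift apart. A hedged sketch of that shared-spec pattern; the import path and the concrete `hidden_dim`/`num_layers` values are assumptions for illustration:

```python
# Sketch only: one (h, c) spec dict shared by input and output specs.
from ray.rllib.core.models.specs.specs_base import TensorSpec  # assumed import path

hidden_dim, num_layers = 256, 1

state_in_out_spec = {
    "h": TensorSpec("b, l, d", d=hidden_dim, l=num_layers, framework="torch"),
    "c": TensorSpec("b, l, d", d=hidden_dim, l=num_layers, framework="torch"),
}

# The same dict backs both STATE_IN (get_input_specs) and STATE_OUT
# (get_output_specs) in the encoder above.
input_state_spec = state_in_out_spec
output_state_spec = state_in_out_spec
```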

rllib/core/rl_module/rl_module.py

Lines changed: 2 additions & 2 deletions
@@ -474,9 +474,9 @@ def get_initial_state(self) -> Any:
 
     @OverrideToImplementCustomLogic
     def is_stateful(self) -> bool:
-        """Returns True if the initial state is empty.
+        """Returns False if the initial state is an empty dict (or None).
 
-        By default, RLlib assumes that the module is not recurrent if the initial
+        By default, RLlib assumes that the module is non-recurrent if the initial
         state is an empty dict and recurrent otherwise.
         This behavior can be overridden by implementing this method.
         """

rllib/env/env_runner.py

Lines changed: 0 additions & 1 deletion
@@ -1,5 +1,4 @@
 import abc
-
 from typing import Any, Dict, TYPE_CHECKING
 
 from ray.rllib.utils.actor_manager import FaultAwareApply

rllib/env/multi_agent_episode.py

Lines changed: 22 additions & 22 deletions
@@ -212,75 +212,75 @@ def get_observations(
 
         return self._getattr_by_index("observations", indices, global_ts)
 
-    def get_actions(
+    def get_infos(
         self, indices: Union[int, List[int]] = -1, global_ts: bool = True
     ) -> MultiAgentDict:
-        """Gets actions for all agents that stepped in the last timesteps.
+        """Gets infos for all agents that stepped in the last timesteps.
 
-        Note that actions are only returned for agents that stepped
+        Note that infos are only returned for agents that stepped
         during the given index range.
 
         Args:
            indices: Either a single index or a list of indices. The indices
                can be reversed (e.g. [-1, -2]) or absolute (e.g. [98, 99]).
-               This defines the time indices for which the actions
+               This defines the time indices for which the infos
               should be returned.
            global_ts: Boolean that defines, if the indices should be considered
               environment (`True`) or agent (`False`) steps.
 
-        Returns: A dictionary mapping agent ids to actions (of different
+        Returns: A dictionary mapping agent ids to infos (of different
            timesteps). Only for agents that have stepped (were ready) at a
-           timestep, actions are returned (i.e. not all agent ids are
+           timestep, infos are returned (i.e. not all agent ids are
            necessarily in the keys).
        """
+        return self._getattr_by_index("infos", indices, global_ts)
 
-        return self._getattr_by_index("actions", indices, global_ts)
-
-    def get_rewards(
+    def get_actions(
         self, indices: Union[int, List[int]] = -1, global_ts: bool = True
     ) -> MultiAgentDict:
-        """Gets rewards for all agents that stepped in the last timesteps.
+        """Gets actions for all agents that stepped in the last timesteps.
 
-        Note that rewards are only returned for agents that stepped
+        Note that actions are only returned for agents that stepped
        during the given index range.
 
        Args:
            indices: Either a single index or a list of indices. The indices
                can be reversed (e.g. [-1, -2]) or absolute (e.g. [98, 99]).
-               This defines the time indices for which the rewards
+               This defines the time indices for which the actions
               should be returned.
            global_ts: Boolean that defines, if the indices should be considered
               environment (`True`) or agent (`False`) steps.
 
-        Returns: A dictionary mapping agent ids to rewards (of different
+        Returns: A dictionary mapping agent ids to actions (of different
            timesteps). Only for agents that have stepped (were ready) at a
-           timestep, rewards are returned (i.e. not all agent ids are
+           timestep, actions are returned (i.e. not all agent ids are
            necessarily in the keys).
        """
-        return self._getattr_by_index("rewards", indices, global_ts)
 
-    def get_infos(
+        return self._getattr_by_index("actions", indices, global_ts)
+
+    def get_rewards(
         self, indices: Union[int, List[int]] = -1, global_ts: bool = True
     ) -> MultiAgentDict:
-        """Gets infos for all agents that stepped in the last timesteps.
+        """Gets rewards for all agents that stepped in the last timesteps.
 
-        Note that infos are only returned for agents that stepped
+        Note that rewards are only returned for agents that stepped
        during the given index range.
 
        Args:
            indices: Either a single index or a list of indices. The indices
                can be reversed (e.g. [-1, -2]) or absolute (e.g. [98, 99]).
-               This defines the time indices for which the infos
+               This defines the time indices for which the rewards
               should be returned.
            global_ts: Boolean that defines, if the indices should be considered
               environment (`True`) or agent (`False`) steps.
 
-        Returns: A dictionary mapping agent ids to infos (of different
+        Returns: A dictionary mapping agent ids to rewards (of different
            timesteps). Only for agents that have stepped (were ready) at a
-           timestep, infos are returned (i.e. not all agent ids are
+           timestep, rewards are returned (i.e. not all agent ids are
            necessarily in the keys).
        """
-        return self._getattr_by_index("infos", indices, global_ts)
+        return self._getattr_by_index("rewards", indices, global_ts)
 
     def get_extra_model_outputs(
         self, indices: Union[int, List[int]] = -1, global_ts: bool = True
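The three getters were reordered so that each docstring matches its method again; their signatures are unchanged. A hedged usage sketch, assuming a populated `MultiAgentEpisode` instance is passed in:

```python
# Sketch only: using the MultiAgentEpisode getters touched in this diff.
from ray.rllib.env.multi_agent_episode import MultiAgentEpisode


def last_step_summary(episode: MultiAgentEpisode) -> dict:
    """Collects the most recent per-agent data via the getters above."""
    return {
        "infos": episode.get_infos(indices=-1),
        "actions": episode.get_actions(indices=[-1, -2]),
        # global_ts=False switches from env-step to agent-step indexing.
        "rewards": episode.get_rewards(indices=-1, global_ts=False),
    }
```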

rllib/env/single_agent_env_runner.py

Lines changed: 7 additions & 18 deletions
@@ -23,7 +23,7 @@
     from ray.rllib.algorithms.algorithm_config import AlgorithmConfig
 
 # TODO (sven): This gives a tricky circular import that goes
-#  deep into the library. We have to see, where to dissolve it.
+# deep into the library. We have to see, where to dissolve it.
 from ray.rllib.env.single_agent_episode import SingleAgentEpisode
 
 _, tf, _ = try_import_tf()
@@ -41,6 +41,8 @@ def __init__(self, config: "AlgorithmConfig", **kwargs):
         # Get the worker index on which this instance is running.
         self.worker_index: int = kwargs.get("worker_index")
 
+        # Create the vectorized gymnasium env.
+
         # Register env for the local context.
         # Note, `gym.register` has to be called on each worker.
         if isinstance(self.config.env, str) and _global_registry.contains(
@@ -59,7 +61,6 @@ def __init__(self, config: "AlgorithmConfig", **kwargs):
             )
             gym.register("rllib-single-agent-env-runner-v0", entry_point=entry_point)
 
-        # Create the vectorized gymnasium env.
         # Wrap into `VectorListInfo`` wrapper to get infos as lists.
         self.env: gym.Wrapper = gym.wrappers.VectorListInfo(
             gym.vector.make(
@@ -68,31 +69,19 @@ def __init__(self, config: "AlgorithmConfig", **kwargs):
                 asynchronous=self.config.remote_worker_envs,
             )
         )
-
         self.num_envs: int = self.env.num_envs
         assert self.num_envs == self.config.num_envs_per_worker
 
-        # Create our own instance of the single-agent `RLModule` (which
+        # Create our own instance of the (single-agent) `RLModule` (which
         # the needs to be weight-synched) each iteration.
-        # TODO (sven, simon): We need to get rid here of the policy_dict,
-        # but the 'RLModule' takes the 'policy_spec.observation_space'
-        # from it.
-        # Below is the non nice solution.
-        # policy_dict, _ = self.config.get_multi_agent_setup(env=self.env)
         module_spec: SingleAgentRLModuleSpec = self.config.get_default_rl_module_spec()
         module_spec.observation_space = self.env.envs[0].observation_space
         # TODO (simon): The `gym.Wrapper` for `gym.vector.VectorEnv` should
-        #  actually hold the spaces for a single env, but for boxes the
-        #  shape is (1, 1) which brings a problem with the action dists.
-        #  shape=(1,) is expected.
+        # actually hold the spaces for a single env, but for boxes the
+        # shape is (1, 1) which brings a problem with the action dists.
+        # shape=(1,) is expected.
         module_spec.action_space = self.env.envs[0].action_space
         module_spec.model_config_dict = self.config.model
-
-        # TODO (sven): By time the `AlgorithmConfig` will get rid of `PolicyDict`
-        # as well. Then we have to change this function parameter.
-        # module_spec: MultiAgentRLModuleSpec = self.config.get_marl_module_spec(
-        #     policy_dict=module_dict
-        # )
         self.module: RLModule = module_spec.build()
 
         # This should be the default.
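The cleaned-up constructor boils down to: take the default `SingleAgentRLModuleSpec` from the config, patch in the per-env spaces, then `build()` the RLModule. A hedged standalone sketch of that flow, with `CartPole-v1` and PPO used purely as examples:

```python
# Sketch only: the spec -> RLModule flow used in SingleAgentEnvRunner.__init__.
import gymnasium as gym

from ray.rllib.algorithms.ppo import PPOConfig

config = PPOConfig().environment("CartPole-v1")
env = gym.make("CartPole-v1")

module_spec = config.get_default_rl_module_spec()
# As in the runner: take the spaces from a single, non-vectorized env.
module_spec.observation_space = env.observation_space
module_spec.action_space = env.action_space
module_spec.model_config_dict = config.model

module = module_spec.build()
```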

rllib/env/tests/test_single_agent_episode.py

Lines changed: 1 addition & 1 deletion
@@ -9,7 +9,7 @@
 from ray.rllib.env.single_agent_episode import SingleAgentEpisode
 
 # TODO (simon): Add to the tests `info` and `extra_model_outputs`
-#  as soon as #39732 is merged.
+# as soon as #39732 is merged.
 
 
 class TestEnv(gym.Env):
