
Commit 6da2636

[RLlib] New ConnectorV2 API #6: Changes in SingleAgentEpisode & SingleAgentEnvRunner. (ray-project#42296)
1 parent 3a306ef commit 6da2636

File tree

15 files changed: +1114 −630 lines changed


rllib/algorithms/algorithm_config.py

Lines changed: 11 additions & 0 deletions
@@ -332,6 +332,7 @@ def __init__(self, algo_class=None):
         self.enable_connectors = True
         self._env_to_module_connector = None
         self._module_to_env_connector = None
+        self.episode_lookback_horizon = 1
         # TODO (sven): Rename into `sample_timesteps` (or `sample_duration`
         # and `sample_duration_unit` (replacing batch_mode), like we do it
         # in the evaluation config).
@@ -1405,6 +1406,7 @@ def rollouts(
         module_to_env_connector: Optional[
             Callable[[EnvType, "RLModule"], "ConnectorV2"]
         ] = NotProvided,
+        episode_lookback_horizon: Optional[int] = NotProvided,
         use_worker_filter_stats: Optional[bool] = NotProvided,
         update_worker_filter_stats: Optional[bool] = NotProvided,
         rollout_fragment_length: Optional[Union[int, str]] = NotProvided,
@@ -1455,6 +1457,13 @@ def rollouts(
             module_to_env_connector: A callable taking an Env and an RLModule as input
                 args and returning a module-to-env ConnectorV2 (might be a pipeline)
                 object.
+            episode_lookback_horizon: The amount of data (in timesteps) to keep from the
+                preceding episode chunk when a new chunk (for the same episode) is
+                generated to continue sampling at a later time. The larger this value,
+                the more an env-to-module connector will be able to look back in time
+                and compile RLModule input data from this information. For example, if
+                your custom env-to-module connector (and your custom RLModule) requires
+                the previous 10 rewards as inputs, you must set this to at least 10.
             use_worker_filter_stats: Whether to use the workers in the WorkerSet to
                 update the central filters (held by the local worker). If False, stats
                 from the workers will not be used and discarded.
@@ -1550,6 +1559,8 @@ def rollouts(
             self._env_to_module_connector = env_to_module_connector
         if module_to_env_connector is not NotProvided:
             self._module_to_env_connector = module_to_env_connector
+        if episode_lookback_horizon is not NotProvided:
+            self.episode_lookback_horizon = episode_lookback_horizon
         if use_worker_filter_stats is not NotProvided:
            self.use_worker_filter_stats = use_worker_filter_stats
        if update_worker_filter_stats is not NotProvided:
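The new option plugs into the existing `rollouts()` config call. A minimal usage sketch (assuming `PPOConfig` and a custom env-to-module connector that needs the previous 10 rewards; the connector itself is not shown here):

```python
from ray.rllib.algorithms.ppo import PPOConfig

config = (
    PPOConfig()
    .environment("CartPole-v1")
    .rollouts(
        # Keep the last 10 timesteps of each preceding episode chunk around so
        # a custom env-to-module connector can look that far back in time.
        episode_lookback_horizon=10,
    )
)
```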

rllib/algorithms/ppo/tests/test_ppo_with_rl_module.py

Lines changed: 3 additions & 3 deletions
@@ -99,11 +99,11 @@ def test_ppo_compilation_and_schedule_mixins(self):

         num_iterations = 2

-        for fw in framework_iterator(config, frameworks=("torch", "tf2")):
+        for fw in framework_iterator(config, frameworks=("tf2", "torch")):
             # TODO (Kourosh) Bring back "FrozenLake-v1"
             for env in ["CartPole-v1", "Pendulum-v1", "ALE/Breakout-v5"]:
                 print("Env={}".format(env))
-                for lstm in [False, True]:
+                for lstm in [True, False]:
                     print("LSTM={}".format(lstm))
                     config.training(model=get_model_config(fw, lstm=lstm))

@@ -175,7 +175,7 @@ def test_ppo_exploration_setup(self):
                        obs, prev_action=np.array(2), prev_reward=np.array(1.0)
                    )
                )
-            check(np.mean(actions), 1.5, atol=0.2)
+            check(np.mean(actions), 1.5, atol=0.49)
             algo.stop()

     def test_ppo_free_log_std_with_rl_modules(self):

rllib/connectors/connector_v2.py

Lines changed: 1 addition & 4 deletions
@@ -134,10 +134,7 @@ def __call__(
         """Method for transforming input data into output data.

         Args:
-            rl_module: An optional RLModule object that the connector might need to know
-                about. Note that normally, only module-to-env connectors get this
-                information at construction time, but env-to-module and learner
-                connectors won't (b/c they get constructed before the RLModule).
+            rl_module: The RLModule object that the connector connects to or from.
             data: The input data abiding to `self.input_type` to be transformed by
                 this connector. Transformations might either be done in-place or a new
                 structure may be returned that matches `self.output_type`.
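For reference, a minimal sketch of a custom ConnectorV2 piece that uses the `rl_module` argument documented above. The class name is hypothetical, and the keyword-only parameters beyond `rl_module` and `data` (e.g. `episodes`) are assumptions, not taken from this diff:

```python
from ray.rllib.connectors.connector_v2 import ConnectorV2


class FrameworkTagConnector(ConnectorV2):
    """Hypothetical connector piece: tags the batch with the RLModule's framework."""

    def __call__(self, *, rl_module, data, episodes=None, **kwargs):
        # `rl_module` is the RLModule this connector connects to or from.
        # Transform `data` in place (or return a new matching structure).
        data["_framework"] = rl_module.framework
        return data
```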

rllib/connectors/env_to_module/default_env_to_module.py

Lines changed: 14 additions & 0 deletions
@@ -8,11 +8,16 @@
 from ray.rllib.core.rl_module.rl_module import RLModule
 from ray.rllib.policy.sample_batch import SampleBatch
 from ray.rllib.utils.annotations import override
+from ray.rllib.utils.framework import try_import_tf
 from ray.rllib.utils.spaces.space_utils import batch
+from ray.rllib.utils.torch_utils import convert_to_torch_tensor
 from ray.rllib.utils.typing import EpisodeType
 from ray.util.annotations import PublicAPI


+_, tf, _ = try_import_tf()
+
+
 @PublicAPI(stability="alpha")
 class DefaultEnvToModule(ConnectorV2):
     """Default connector piece added by RLlib to the end of any env-to-module pipeline.
@@ -77,4 +82,13 @@ def __call__(
         # Note that state ins should NOT have the extra time dimension.
         data[STATE_IN] = batch(states)

+        # Convert data to proper tensor formats, depending on framework used by the
+        # RLModule.
+        # TODO (sven): Support GPU-based EnvRunners + RLModules for sampling. Right
+        # now we assume EnvRunners are always only on the CPU.
+        if rl_module.framework == "torch":
+            data = convert_to_torch_tensor(data)
+        elif rl_module.framework == "tf2":
+            data = tree.map_structure(lambda s: tf.convert_to_tensor(s), data)
+
         return data
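Outside the connector, the two conversion branches added above boil down to the following standalone sketch (assuming numpy inputs on the CPU; `tree` is the dm-tree package RLlib already uses for nested-structure operations):

```python
import numpy as np
import tree  # dm_tree

from ray.rllib.utils.framework import try_import_tf
from ray.rllib.utils.torch_utils import convert_to_torch_tensor

_, tf, _ = try_import_tf()

batch = {"obs": np.zeros((1, 4), dtype=np.float32)}

# Torch branch: nested numpy arrays -> torch.Tensors (on CPU).
torch_batch = convert_to_torch_tensor(batch)
# TF2 branch: nested numpy arrays -> tf.Tensors.
tf_batch = tree.map_structure(lambda s: tf.convert_to_tensor(s), batch)
```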

rllib/connectors/learner/default_learner_connector.py

Lines changed: 9 additions & 0 deletions
@@ -161,6 +161,15 @@ def __call__(
             T=T,
         )

+        # TODO (sven): Convert data to proper tensor formats, depending on framework
+        # used by the RLModule. We cannot do this right now as the RLModule does NOT
+        # know its own device. Only the Learner knows the device. Also, on the
+        # EnvRunner side, we assume that it's always the CPU (even though one could
+        # imagine a GPU-based EnvRunner + RLModule for sampling).
+        # if rl_module.framework == "torch":
+        #     data = convert_to_torch_tensor(data, device=??)
+        # elif rl_module.framework == "tf2":
+        #     data =
         return data
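Once the device question raised in this TODO is resolved, the deferred torch-side conversion could look roughly like the following. This is a purely hypothetical sketch: the helper name `_to_learner_device` and the idea of calling it from the Learner are assumptions, not part of this commit.

```python
from ray.rllib.utils.torch_utils import convert_to_torch_tensor


def _to_learner_device(data, device):
    # Hypothetical helper: move the (CPU/numpy) connector output onto the
    # device that the Learner actually holds its RLModule on.
    return convert_to_torch_tensor(data, device=device)
```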
