[Feature] RayReplayBuffer #2835
Changes from all commits
@@ -0,0 +1,82 @@
"""
Example use of an ever-running, fully async, distributed collector
==================================================================

This example demonstrates how to set up and use a distributed collector
with Ray in a fully asynchronous manner. The collector continuously gathers
data from a gym environment and stores it in a replay buffer, allowing for
concurrent processing and data collection.

Key Components:

1. **Environment Factory**: A simple function that creates instances of the
   `GymEnv` environment. In this example, we use the "Pendulum-v1" environment.
2. **Policy Definition**: A `TensorDictModule` that defines the policy network.
   Here, a simple linear layer is used to map observations to actions.
3. **Replay Buffer**: A `RayReplayBuffer` that stores collected data for later
   use, such as training a reinforcement learning model.
4. **Distributed Collector**: A `RayCollector` that manages the distributed
   collection of data. It is configured with remote resources and interacts
   with the environment and policy to gather data.
5. **Asynchronous Execution**: The collector runs in the background, allowing
   the main program to perform other tasks concurrently. The example includes
   a loop that waits for data to be available in the buffer and samples it.
6. **Graceful Shutdown**: The collector is shut down asynchronously, ensuring
   that all resources are properly released.

This setup is useful for scenarios where you need to collect data from
multiple environments in parallel, leveraging Ray's distributed computing
capabilities to scale efficiently.
"""
import asyncio

from tensordict.nn import TensorDictModule
from torch import nn
from torchrl.collectors.distributed.ray import RayCollector
from torchrl.data.replay_buffers.ray_buffer import RayReplayBuffer
from torchrl.envs.libs.gym import GymEnv


async def main():
    # 1. Create environment factory
    def env_maker():
        return GymEnv("Pendulum-v1", device="cpu")

    policy = TensorDictModule(
        nn.Linear(3, 1), in_keys=["observation"], out_keys=["action"]
    )

    buffer = RayReplayBuffer()

    # 2. Define distributed collector
    remote_config = {
        "num_cpus": 1,
        "num_gpus": 0,
        "memory": 5 * 1024**3,
        "object_store_memory": 2 * 1024**3,
    }
    distributed_collector = RayCollector(
        [env_maker],
        policy,
        total_frames=600,
        frames_per_batch=200,
        remote_configs=remote_config,
        replay_buffer=buffer,
    )

    print("start")
    distributed_collector.start()

    while True:
        while not len(buffer):
            print("waiting")
            await asyncio.sleep(1)  # Use asyncio.sleep instead of time.sleep
        print("sample", buffer.sample(32))
        # break at some point
        break

    await distributed_collector.async_shutdown()


if __name__ == "__main__":
    asyncio.run(main())
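
As a complement to the example, here is a minimal, illustrative sketch (not part of this PR) of a consumer that samples from the buffer while the collector fills it in the background; the objective, optimizer settings, and the `consume` helper are placeholders, not torchrl APIs:

# Illustrative only -- not part of this diff. Assumes a RayReplayBuffer that is
# being filled by a running RayCollector, and a policy like the one above.
import asyncio

import torch
from tensordict.nn import TensorDictModule


async def consume(buffer, policy: TensorDictModule, num_updates: int = 10):
    optim = torch.optim.Adam(policy.parameters(), lr=3e-4)
    for _ in range(num_updates):
        while not len(buffer):
            await asyncio.sleep(1)  # wait for the collector to add data
        batch = buffer.sample(32)  # TensorDict holding "observation" and "action"
        # Placeholder objective: regress the policy output onto the stored actions.
        loss = (policy(batch.clone())["action"] - batch["action"]).pow(2).mean()
        optim.zero_grad()
        loss.backward()
        optim.step()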

@@ -5,6 +5,7 @@
+from __future__ import annotations

import asyncio
import warnings
from typing import Callable, Iterator, OrderedDict

@@ -21,6 +22,7 @@
    SyncDataCollector,
)
from torchrl.collectors.utils import _NON_NN_POLICY_WEIGHTS, split_trajectories
+from torchrl.data import ReplayBuffer
from torchrl.envs.common import EnvBase
from torchrl.envs.env_creator import EnvCreator

@@ -256,6 +258,11 @@ class RayCollector(DataCollectorBase):
            parameters being updated for a certain time even if ``update_after_each_batch``
            is turned on.
            Defaults to -1 (no forced update).
+        replay_buffer (RayReplayBuffer, optional): if provided, the collector will not yield tensordicts
+            but populate the buffer instead. Defaults to ``None``.
+
+            .. note:: although it is not enforced (to allow users to implement their own replay buffer class), a
+              :class:`~torchrl.data.RayReplayBuffer` instance should be used here.

    Examples:
        >>> from torch import nn

@@ -312,7 +319,9 @@ def __init__(
        num_collectors: int = None,
        update_after_each_batch=False,
        max_weight_update_interval=-1,
+        replay_buffer: ReplayBuffer = None,

Contributor comment: nit: this is not documented

    ):
        self.frames_per_batch = frames_per_batch
        if remote_configs is None:
            remote_configs = DEFAULT_REMOTE_CLASS_CONFIG

@@ -321,6 +330,14 @@ def __init__(
        if collector_kwargs is None:
            collector_kwargs = {}
+        if replay_buffer is not None:
+            if isinstance(collector_kwargs, dict):
+                collector_kwargs.setdefault("replay_buffer", replay_buffer)
+            else:
+                collector_kwargs = [
+                    ck.setdefault("replay_buffer", replay_buffer)
+                    for ck in collector_kwargs
+                ]

        # Make sure input parameters are consistent
        def check_consistency_with_num_collectors(param, param_name, num_collectors):

@@ -386,7 +403,8 @@ def check_list_length_consistency(*lists):
            raise RuntimeError(
                "ray library not found, unable to create a DistributedCollector. "
            ) from RAY_ERR
-        ray.init(**ray_init_config)
+        if not ray.is_initialized():
+            ray.init(**ray_init_config)
        if not ray.is_initialized():
            raise RuntimeError("Ray could not be initialized.")
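
The guard above lets the collector reuse a Ray session the user has already started instead of unconditionally re-initializing it. A hedged sketch of that scenario (not part of the diff; the num_cpus value is arbitrary):

# Illustrative only: start (or reuse) a user-managed Ray session first;
# RayCollector will then detect the live session and skip its own ray.init().
import ray

if not ray.is_initialized():
    ray.init(num_cpus=4)
# ... construct RayCollector as in the example file; Ray is not initialized twice.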

@@ -400,6 +418,7 @@ def check_list_length_consistency(*lists):
        collector_class.as_remote = as_remote
        collector_class.print_remote_collector_info = print_remote_collector_info

+        self.replay_buffer = replay_buffer
        self._local_policy = policy
        if isinstance(self._local_policy, nn.Module):
            policy_weights = TensorDict.from_module(self._local_policy)

@@ -557,7 +576,7 @@ def add_collectors(
                policy,
                other_params,
            )
-            self._remote_collectors.extend([collector])
+            self._remote_collectors.append(collector)

    def local_policy(self):
        """Returns local collector."""

@@ -577,17 +596,33 @@ def stop_remote_collectors(self):
            )  # This will interrupt any running tasks on the actor, causing them to fail immediately

    def iterator(self):

Contributor comment: Should iterator raise an error for the async case? Or what happens when trying to iterate over this collector when a replay buffer was passed?

Author reply: You should get None outputs, and the buffer is filled each time you call next().
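
To illustrate the behaviour described in the reply, a minimal sketch (not part of the diff); it assumes `distributed_collector` and `buffer` were built as in the example file above, with the buffer passed to the collector:

# Illustrative only: with replay_buffer set, iterating the collector extends
# the buffer in the background and yields None instead of tensordicts.
for batch in distributed_collector:
    assert batch is None            # nothing is handed back to the caller
    if len(buffer):
        sample = buffer.sample(32)  # collected data lands in the buffer instead
        break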

+        def proc(data):
+            if self.split_trajs:
+                data = split_trajectories(data)
+            if self.postproc is not None:
+                data = self.postproc(data)
+            return data
+
        if self._sync:
-            data = self._sync_iterator()
+            meth = self._sync_iterator
        else:
-            data = self._async_iterator()
+            meth = self._async_iterator
+        yield from (proc(data) for data in meth())

-        if self.split_trajs:
-            data = split_trajectories(data)
-        if self.postproc is not None:
-            data = self.postproc(data)
+    async def _asyncio_iterator(self):
+        def proc(data):
+            if self.split_trajs:
+                data = split_trajectories(data)
+            if self.postproc is not None:
+                data = self.postproc(data)
+            return data

-        return data
+        if self._sync:
+            for d in self._sync_iterator():
+                yield proc(d)
+        else:
+            for d in self._async_iterator():
+                yield proc(d)

    def _sync_iterator(self) -> Iterator[TensorDictBase]:
        """Collects one data batch per remote collector in each iteration."""

@@ -634,7 +669,30 @@ def _sync_iterator(self) -> Iterator[TensorDictBase]:
            ):
                self.update_policy_weights_(rank)

-        self.shutdown()
+        if self._task is None:
+            self.shutdown()

+    _task = None
+
+    def start(self):
+        """Starts the RayCollector."""
+        if self.replay_buffer is None:
+            raise RuntimeError("Replay buffer must be defined for asyncio execution.")
+        if self._task is None or self._task.done():
+            loop = asyncio.get_event_loop()
+            self._task = loop.create_task(self._run_iterator_silently())
+
+    async def _run_iterator_silently(self):
+        async for _ in self._asyncio_iterator():
+            # Process each item silently
+            continue
+
+    async def async_shutdown(self):
+        """Finishes processes started by ray.init() during async execution."""
+        if self._task is not None:
+            await self._task
+        self.stop_remote_collectors()
+        ray.shutdown()

    def _async_iterator(self) -> Iterator[TensorDictBase]:
        """Collects a data batch from a single remote collector in each iteration."""

@@ -658,7 +716,7 @@ def _async_iterator(self) -> Iterator[TensorDictBase]:
                ray.internal.free(
                    [future]
                )  # should not be necessary, deleted automatically when ref count is down to 0
-                self.collected_frames += out_td.numel()
+                self.collected_frames += self.frames_per_batch

                yield out_td

@@ -689,8 +747,8 @@ def _async_iterator(self) -> Iterator[TensorDictBase]:
                # object_ref=ref,
                # force=False,
                # )

-        self.shutdown()
+        if self._task is None:
+            self.shutdown()

    def update_policy_weights_(self, worker_rank=None) -> None:
        """Updates the weights of the worker nodes.

Comment: perhaps the example should be updated to use `.from_policy_factory` in the PR above after that's merged.