2 changes: 2 additions & 0 deletions docs/source/reference/data.rst
@@ -20,6 +20,8 @@ widely used replay buffers:
PrioritizedReplayBuffer
TensorDictReplayBuffer
TensorDictPrioritizedReplayBuffer
RayReplayBuffer
RemoteTensorDictReplayBuffer

Composable Replay Buffers
-------------------------
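A minimal usage sketch of the newly documented RayReplayBuffer (an assumption-laden illustration, not part of this diff: it presumes Ray is installed and reachable; the constructor arguments mirror those exercised in test_ray_replaybuffer further down in this PR):

from functools import partial

import torch
from tensordict import TensorDict
from torchrl.data import LazyTensorStorage, RayReplayBuffer

# The buffer lives in a Ray actor; reads and writes go through remote calls.
rb = RayReplayBuffer(storage=partial(LazyTensorStorage, 1000), batch_size=32)
td = TensorDict(a=torch.arange(100, 200), batch_size=[100])
index = rb.extend(td)  # returns the indices written to, as asserted in the test below
sample = rb.sample()   # draws a batch of 32 transitions through the remote actor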
82 changes: 82 additions & 0 deletions examples/distributed/collectors/multi_nodes/ray_buffer_infra.py
@@ -0,0 +1,82 @@
"""
Example use of an ever-running, fully async, distributed collector
==================================================================
This example demonstrates how to set up and use a distributed collector
with Ray in a fully asynchronous manner. The collector continuously gathers
data from a gym environment and stores it in a replay buffer, allowing for
concurrent processing and data collection.
Key Components:
1. **Environment Factory**: A simple function that creates instances of the
`GymEnv` environment. In this example, we use the "Pendulum-v1" environment.
2. **Policy Definition**: A `TensorDictModule` that defines the policy network.
Here, a simple linear layer is used to map observations to actions.
3. **Replay Buffer**: A `RayReplayBuffer` that stores collected data for later
use, such as training a reinforcement learning model.
4. **Distributed Collector**: A `RayCollector` that manages the distributed
collection of data. It is configured with remote resources and interacts
with the environment and policy to gather data.
5. **Asynchronous Execution**: The collector runs in the background, allowing
the main program to perform other tasks concurrently. The example includes
a loop that waits for data to be available in the buffer and samples it.
6. **Graceful Shutdown**: The collector is shut down asynchronously, ensuring
that all resources are properly released.
This setup is useful for scenarios where you need to collect data from
multiple environments in parallel, leveraging Ray's distributed computing
capabilities to scale efficiently.
"""
import asyncio

from tensordict.nn import TensorDictModule
from torch import nn
from torchrl.collectors.distributed.ray import RayCollector
from torchrl.data.replay_buffers.ray_buffer import RayReplayBuffer
from torchrl.envs.libs.gym import GymEnv


async def main():
# 1. Create environment factory
def env_maker():
return GymEnv("Pendulum-v1", device="cpu")

policy = TensorDictModule(
Contributor: perhaps the example should be updated to use the .from_policy_factory in the PR above after that's merged

nn.Linear(3, 1), in_keys=["observation"], out_keys=["action"]
)

buffer = RayReplayBuffer()

# 2. Define distributed collector
remote_config = {
"num_cpus": 1,
"num_gpus": 0,
"memory": 5 * 1024**3,
"object_store_memory": 2 * 1024**3,
}
distributed_collector = RayCollector(
[env_maker],
policy,
total_frames=600,
frames_per_batch=200,
remote_configs=remote_config,
replay_buffer=buffer,
)

print("start")
distributed_collector.start()

while True:
while not len(buffer):
print("waiting")
await asyncio.sleep(1) # Use asyncio.sleep instead of time.sleep
print("sample", buffer.sample(32))
# break at some point
break

await distributed_collector.async_shutdown()


if __name__ == "__main__":
asyncio.run(main())
3 changes: 2 additions & 1 deletion examples/distributed/collectors/multi_nodes/ray_collect.py
@@ -27,7 +27,7 @@ def env_maker():
# 2. Define distributed collector
remote_config = {
"num_cpus": 1,
"num_gpus": 0.2,
"num_gpus": 0,
"memory": 5 * 1024**3,
"object_store_memory": 2 * 1024**3,
}
Expand All @@ -36,6 +36,7 @@ def env_maker():
policy,
total_frames=10000,
frames_per_batch=200,
remote_configs=remote_config,
)

# Sample batches until reaching total_frames
41 changes: 41 additions & 0 deletions test/test_distributed.py
@@ -11,10 +11,19 @@
import os
import sys
import time
from functools import partial

import pytest
from tensordict import TensorDict
from tensordict.nn import TensorDictModuleBase
from torchrl._utils import logger as torchrl_logger
from torchrl.data import (
LazyTensorStorage,
RandomSampler,
RayReplayBuffer,
RoundRobinWriter,
SamplerWithoutReplacement,
)

try:
import ray
@@ -435,6 +444,15 @@ class TestRayCollector(DistributedCollectorBase):
to avoid potential deadlocks when combining Ray and multiprocessing.
"""

@pytest.fixture(autouse=True, scope="class")
def start_ray(self):
from torchrl.collectors.distributed.ray import DEFAULT_RAY_INIT_CONFIG

ray.init(**DEFAULT_RAY_INIT_CONFIG)

yield
ray.shutdown()

@classmethod
def distributed_class(cls) -> type:
return RayCollector
@@ -552,6 +570,29 @@ def test_distributed_collector_updatepolicy(self, collector_class, sync):
collector.shutdown()
assert total == total_frames

@pytest.mark.parametrize("storage", [None, partial(LazyTensorStorage, 1000)])
@pytest.mark.parametrize(
"sampler", [None, partial(RandomSampler), SamplerWithoutReplacement]
)
@pytest.mark.parametrize("writer", [None, partial(RoundRobinWriter)])
def test_ray_replaybuffer(self, storage, sampler, writer):
kwargs = self.distributed_kwargs()
kwargs["remote_config"] = kwargs.pop("remote_configs")
rb = RayReplayBuffer(
storage=storage,
sampler=sampler,
writer=writer,
batch_size=32,
**kwargs,
)
td = TensorDict(a=torch.arange(100, 200), batch_size=[100])
index = rb.extend(td)
assert (index == torch.arange(100)).all()
for _ in range(10):
sample = rb.sample()
if sampler is SamplerWithoutReplacement:
assert sample["a"].unique().numel() == sample.numel()


if __name__ == "__main__":
args, unknown = argparse.ArgumentParser().parse_known_args()
9 changes: 5 additions & 4 deletions torchrl/collectors/collectors.py
@@ -267,7 +267,8 @@ def next(self):
self._iterator = iter(self)
out = next(self._iterator)
# if any, we don't want the device ref to be passed in distributed settings
out.clear_device_()
if out is not None:
out.clear_device_()
return out
except StopIteration:
return None
@@ -432,7 +433,7 @@ class SyncDataCollector(DataCollectorBase):
use_buffers (bool, optional): if ``True``, a buffer will be used to stack the data.
This isn't compatible with environments with dynamic specs. Defaults to ``True``
for envs without dynamic specs, ``False`` for others.
replay_buffer (ReplayBuffer, optional): if provided, the collector will not yield tensordict
replay_buffer (ReplayBuffer, optional): if provided, the collector will not yield tensordicts
but populate the buffer instead. Defaults to ``None``.
trust_policy (bool, optional): if ``True``, a non-TensorDictModule policy will be
assumed to be compatible with the collector. This defaults to ``True`` for CudaGraphModules
@@ -1430,7 +1431,7 @@ def load_state_dict(self, state_dict: OrderedDict, **kwargs) -> None:
def __repr__(self) -> str:
env_str = indent(f"env={self.env}", 4 * " ")
policy_str = indent(f"policy={self.policy}", 4 * " ")
td_out_str = indent(f"td_out={self._final_rollout}", 4 * " ")
td_out_str = indent(f"td_out={getattr(self, '_final_rollout', None)}", 4 * " ")
string = (
f"{self.__class__.__name__}("
f"\n{env_str},"
@@ -1586,7 +1587,7 @@ class _MultiDataCollector(DataCollectorBase):
use_buffers (bool, optional): if ``True``, a buffer will be used to stack the data.
This isn't compatible with environments with dynamic specs. Defaults to ``True``
for envs without dynamic specs, ``False`` for others.
replay_buffer (ReplayBuffer, optional): if provided, the collector will not yield tensordict
replay_buffer (ReplayBuffer, optional): if provided, the collector will not yield tensordicts
but populate the buffer instead. Defaults to ``None``.
trust_policy (bool, optional): if ``True``, a non-TensorDictModule policy will be
assumed to be compatible with the collector. This defaults to ``True`` for CudaGraphModules
84 changes: 71 additions & 13 deletions torchrl/collectors/distributed/ray.py
@@ -5,6 +5,7 @@

from __future__ import annotations

import asyncio
import warnings
from typing import Callable, Iterator, OrderedDict

@@ -21,6 +22,7 @@
SyncDataCollector,
)
from torchrl.collectors.utils import _NON_NN_POLICY_WEIGHTS, split_trajectories
from torchrl.data import ReplayBuffer
from torchrl.envs.common import EnvBase
from torchrl.envs.env_creator import EnvCreator

@@ -256,6 +258,11 @@ class RayCollector(DataCollectorBase):
parameters being updated for a certain time even if ``update_after_each_batch``
is turned on.
Defaults to -1 (no forced update).
replay_buffer (RayReplayBuffer, optional): if provided, the collector will not yield tensordicts
but populate the buffer instead. Defaults to ``None``.

.. note:: although it is not enforced (to allow users to implement their own replay buffer class), a
:class:`~torchrl.data.RayReplayBuffer` instance should be used here.

Examples:
>>> from torch import nn
@@ -312,7 +319,9 @@ def __init__(
num_collectors: int = None,
update_after_each_batch=False,
max_weight_update_interval=-1,
replay_buffer: ReplayBuffer = None,
Contributor: nit: this is not documented

):
self.frames_per_batch = frames_per_batch
if remote_configs is None:
remote_configs = DEFAULT_REMOTE_CLASS_CONFIG

@@ -321,6 +330,14 @@

if collector_kwargs is None:
collector_kwargs = {}
if replay_buffer is not None:
if isinstance(collector_kwargs, dict):
collector_kwargs.setdefault("replay_buffer", replay_buffer)
else:
for ck in collector_kwargs:
ck.setdefault("replay_buffer", replay_buffer)

# Make sure input parameters are consistent
def check_consistency_with_num_collectors(param, param_name, num_collectors):
@@ -386,7 +403,8 @@ def check_list_length_consistency(*lists):
raise RuntimeError(
"ray library not found, unable to create a DistributedCollector. "
) from RAY_ERR
ray.init(**ray_init_config)
if not ray.is_initialized():
ray.init(**ray_init_config)
if not ray.is_initialized():
raise RuntimeError("Ray could not be initialized.")

@@ -400,6 +418,7 @@ def check_list_length_consistency(*lists):
collector_class.as_remote = as_remote
collector_class.print_remote_collector_info = print_remote_collector_info

self.replay_buffer = replay_buffer
self._local_policy = policy
if isinstance(self._local_policy, nn.Module):
policy_weights = TensorDict.from_module(self._local_policy)
@@ -557,7 +576,7 @@ def add_collectors(
policy,
other_params,
)
self._remote_collectors.extend([collector])
self._remote_collectors.append(collector)

def local_policy(self):
"""Returns local collector."""
@@ -577,17 +596,33 @@ def stop_remote_collectors(self):
) # This will interrupt any running tasks on the actor, causing them to fail immediately

def iterator(self):
Contributor: Should iterator raise an error for the async case? Or what happens when trying to iterate over this collector when replaybuffer was passed?

Collaborator (author): You should get None outputs, and the buffer is filled each time you call next()
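A minimal sketch of that behavior (an illustration, not part of this diff; it assumes env_maker, policy, and buffer are built as in the ray_buffer_infra.py example above and that Ray is available). When a replay buffer is passed, iterating the collector yields None while the buffer fills:

collector = RayCollector(
    [env_maker],
    policy,
    total_frames=600,
    frames_per_batch=200,
    replay_buffer=buffer,
)
for batch in collector:
    # batch is None; the collected frames are written to `buffer` instead
    assert batch is None
    if len(buffer):
        print(buffer.sample(32))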

def proc(data):
if self.split_trajs:
data = split_trajectories(data)
if self.postproc is not None:
data = self.postproc(data)
return data

if self._sync:
data = self._sync_iterator()
meth = self._sync_iterator
else:
data = self._async_iterator()
meth = self._async_iterator
yield from (proc(data) for data in meth())

if self.split_trajs:
data = split_trajectories(data)
if self.postproc is not None:
data = self.postproc(data)
async def _asyncio_iterator(self):
def proc(data):
if self.split_trajs:
data = split_trajectories(data)
if self.postproc is not None:
data = self.postproc(data)
return data

return data
if self._sync:
for d in self._sync_iterator():
yield proc(d)
else:
for d in self._async_iterator():
yield proc(d)

def _sync_iterator(self) -> Iterator[TensorDictBase]:
"""Collects one data batch per remote collector in each iteration."""
@@ -634,7 +669,30 @@ def _sync_iterator(self) -> Iterator[TensorDictBase]:
):
self.update_policy_weights_(rank)

self.shutdown()
if self._task is None:
self.shutdown()

_task = None

def start(self):
"""Starts the RayCollector."""
if self.replay_buffer is None:
raise RuntimeError("Replay buffer must be defined for asyncio execution.")
if self._task is None or self._task.done():
loop = asyncio.get_event_loop()
self._task = loop.create_task(self._run_iterator_silently())

async def _run_iterator_silently(self):
async for _ in self._asyncio_iterator():
# Process each item silently
continue

async def async_shutdown(self):
"""Finishes processes started by ray.init() during async execution."""
if self._task is not None:
await self._task
self.stop_remote_collectors()
ray.shutdown()

def _async_iterator(self) -> Iterator[TensorDictBase]:
"""Collects a data batch from a single remote collector in each iteration."""
@@ -658,7 +716,7 @@ def _async_iterator(self) -> Iterator[TensorDictBase]:
ray.internal.free(
[future]
) # should not be necessary, deleted automatically when ref count is down to 0
self.collected_frames += out_td.numel()
self.collected_frames += self.frames_per_batch

yield out_td

@@ -689,8 +747,8 @@ def _async_iterator(self) -> Iterator[TensorDictBase]:
# object_ref=ref,
# force=False,
# )

self.shutdown()
if self._task is None:
self.shutdown()

def update_policy_weights_(self, worker_rank=None) -> None:
"""Updates the weights of the worker nodes.