diff --git a/docs/source/reference/data.rst b/docs/source/reference/data.rst
index abcd96e6457..52cca7c5ce6 100644
--- a/docs/source/reference/data.rst
+++ b/docs/source/reference/data.rst
@@ -20,6 +20,8 @@ widely used replay buffers:
     PrioritizedReplayBuffer
     TensorDictReplayBuffer
     TensorDictPrioritizedReplayBuffer
+    RayReplayBuffer
+    RemoteTensorDictReplayBuffer
 
 Composable Replay Buffers
 -------------------------
diff --git a/examples/distributed/collectors/multi_nodes/ray_buffer_infra.py b/examples/distributed/collectors/multi_nodes/ray_buffer_infra.py
new file mode 100644
index 00000000000..0a17ef7279e
--- /dev/null
+++ b/examples/distributed/collectors/multi_nodes/ray_buffer_infra.py
@@ -0,0 +1,82 @@
+"""
+Example use of an ever-running, fully async, distributed collector
+==================================================================
+
+This example demonstrates how to set up and use a distributed collector
+with Ray in a fully asynchronous manner. The collector continuously gathers
+data from a gym environment and stores it in a replay buffer, allowing for
+concurrent processing and data collection.
+
+Key Components:
+1. **Environment Factory**: A simple function that creates instances of the
+   `GymEnv` environment. In this example, we use the "Pendulum-v1" environment.
+2. **Policy Definition**: A `TensorDictModule` that defines the policy network.
+   Here, a simple linear layer is used to map observations to actions.
+3. **Replay Buffer**: A `RayReplayBuffer` that stores collected data for later
+   use, such as training a reinforcement learning model.
+4. **Distributed Collector**: A `RayCollector` that manages the distributed
+   collection of data. It is configured with remote resources and interacts
+   with the environment and policy to gather data.
+5. **Asynchronous Execution**: The collector runs in the background, allowing
+   the main program to perform other tasks concurrently. The example includes
+   a loop that waits for data to be available in the buffer and samples it.
+6. **Graceful Shutdown**: The collector is shut down asynchronously, ensuring
+   that all resources are properly released.
+
+This setup is useful for scenarios where you need to collect data from
+multiple environments in parallel, leveraging Ray's distributed computing
+capabilities to scale efficiently.
+
+"""
+import asyncio
+
+from tensordict.nn import TensorDictModule
+from torch import nn
+from torchrl.collectors.distributed.ray import RayCollector
+from torchrl.data.replay_buffers.ray_buffer import RayReplayBuffer
+from torchrl.envs.libs.gym import GymEnv
+
+
+async def main():
+    # 1. Create environment factory
+    def env_maker():
+        return GymEnv("Pendulum-v1", device="cpu")
+
+    policy = TensorDictModule(
+        nn.Linear(3, 1), in_keys=["observation"], out_keys=["action"]
+    )
+
+    buffer = RayReplayBuffer()
+
+    # 2. Define distributed collector
+    remote_config = {
+        "num_cpus": 1,
+        "num_gpus": 0,
+        "memory": 5 * 1024**3,
+        "object_store_memory": 2 * 1024**3,
+    }
+    distributed_collector = RayCollector(
+        [env_maker],
+        policy,
+        total_frames=600,
+        frames_per_batch=200,
+        remote_configs=remote_config,
+        replay_buffer=buffer,
+    )
+
+    print("start")
+    distributed_collector.start()
+
+    while True:
+        while not len(buffer):
+            print("waiting")
+            await asyncio.sleep(1)  # Use asyncio.sleep instead of time.sleep
+        print("sample", buffer.sample(32))
+        # break at some point
+        break
+
+    await distributed_collector.async_shutdown()
+
+
+if __name__ == "__main__":
+    asyncio.run(main())
diff --git a/examples/distributed/collectors/multi_nodes/ray_collect.py b/examples/distributed/collectors/multi_nodes/ray_collect.py
index e70b0a58d09..6487b6b6459 100644
--- a/examples/distributed/collectors/multi_nodes/ray_collect.py
+++ b/examples/distributed/collectors/multi_nodes/ray_collect.py
@@ -27,7 +27,7 @@ def env_maker():
     # 2. Define distributed collector
     remote_config = {
         "num_cpus": 1,
-        "num_gpus": 0.2,
+        "num_gpus": 0,
         "memory": 5 * 1024**3,
         "object_store_memory": 2 * 1024**3,
     }
@@ -36,6 +36,7 @@ def env_maker():
         policy,
         total_frames=10000,
         frames_per_batch=200,
+        remote_configs=remote_config,
     )
 
     # Sample batches until reaching total_frames
diff --git a/test/test_distributed.py b/test/test_distributed.py
index 2529e4e3d7e..8e939cce42e 100644
--- a/test/test_distributed.py
+++ b/test/test_distributed.py
@@ -11,10 +11,19 @@
 import os
 import sys
 import time
+from functools import partial
 
 import pytest
+from tensordict import TensorDict
 from tensordict.nn import TensorDictModuleBase
 from torchrl._utils import logger as torchrl_logger
+from torchrl.data import (
+    LazyTensorStorage,
+    RandomSampler,
+    RayReplayBuffer,
+    RoundRobinWriter,
+    SamplerWithoutReplacement,
+)
 
 try:
     import ray
@@ -435,6 +444,15 @@ class TestRayCollector(DistributedCollectorBase):
     to avoid potential deadlocks when combining Ray and multiprocessing.
""" + @pytest.fixture(autouse=True, scope="class") + def start_ray(self): + from torchrl.collectors.distributed.ray import DEFAULT_RAY_INIT_CONFIG + + ray.init(**DEFAULT_RAY_INIT_CONFIG) + + yield + ray.shutdown() + @classmethod def distributed_class(cls) -> type: return RayCollector @@ -552,6 +570,29 @@ def test_distributed_collector_updatepolicy(self, collector_class, sync): collector.shutdown() assert total == total_frames + @pytest.mark.parametrize("storage", [None, partial(LazyTensorStorage, 1000)]) + @pytest.mark.parametrize( + "sampler", [None, partial(RandomSampler), SamplerWithoutReplacement] + ) + @pytest.mark.parametrize("writer", [None, partial(RoundRobinWriter)]) + def test_ray_replaybuffer(self, storage, sampler, writer): + kwargs = self.distributed_kwargs() + kwargs["remote_config"] = kwargs.pop("remote_configs") + rb = RayReplayBuffer( + storage=storage, + sampler=sampler, + writer=writer, + batch_size=32, + **kwargs, + ) + td = TensorDict(a=torch.arange(100, 200), batch_size=[100]) + index = rb.extend(td) + assert (index == torch.arange(100)).all() + for _ in range(10): + sample = rb.sample() + if sampler is SamplerWithoutReplacement: + assert sample["a"].unique().numel() == sample.numel() + if __name__ == "__main__": args, unknown = argparse.ArgumentParser().parse_known_args() diff --git a/torchrl/collectors/collectors.py b/torchrl/collectors/collectors.py index 67dff40e9de..37e47b2002e 100644 --- a/torchrl/collectors/collectors.py +++ b/torchrl/collectors/collectors.py @@ -267,7 +267,8 @@ def next(self): self._iterator = iter(self) out = next(self._iterator) # if any, we don't want the device ref to be passed in distributed settings - out.clear_device_() + if out is not None: + out.clear_device_() return out except StopIteration: return None @@ -432,7 +433,7 @@ class SyncDataCollector(DataCollectorBase): use_buffers (bool, optional): if ``True``, a buffer will be used to stack the data. This isn't compatible with environments with dynamic specs. Defaults to ``True`` for envs without dynamic specs, ``False`` for others. - replay_buffer (ReplayBuffer, optional): if provided, the collector will not yield tensordict + replay_buffer (ReplayBuffer, optional): if provided, the collector will not yield tensordicts but populate the buffer instead. Defaults to ``None``. trust_policy (bool, optional): if ``True``, a non-TensorDictModule policy will be trusted to be assumed to be compatible with the collector. This defaults to ``True`` for CudaGraphModules @@ -1430,7 +1431,7 @@ def load_state_dict(self, state_dict: OrderedDict, **kwargs) -> None: def __repr__(self) -> str: env_str = indent(f"env={self.env}", 4 * " ") policy_str = indent(f"policy={self.policy}", 4 * " ") - td_out_str = indent(f"td_out={self._final_rollout}", 4 * " ") + td_out_str = indent(f"td_out={getattr(self, '_final_rollout', None)}", 4 * " ") string = ( f"{self.__class__.__name__}(" f"\n{env_str}," @@ -1586,7 +1587,7 @@ class _MultiDataCollector(DataCollectorBase): use_buffers (bool, optional): if ``True``, a buffer will be used to stack the data. This isn't compatible with environments with dynamic specs. Defaults to ``True`` for envs without dynamic specs, ``False`` for others. - replay_buffer (ReplayBuffer, optional): if provided, the collector will not yield tensordict + replay_buffer (ReplayBuffer, optional): if provided, the collector will not yield tensordicts but populate the buffer instead. Defaults to ``None``. 
         trust_policy (bool, optional): if ``True``, a non-TensorDictModule policy will be trusted to be
             assumed to be compatible with the collector. This defaults to ``True`` for CudaGraphModules
diff --git a/torchrl/collectors/distributed/ray.py b/torchrl/collectors/distributed/ray.py
index 1716609026b..1d56e066f0e 100644
--- a/torchrl/collectors/distributed/ray.py
+++ b/torchrl/collectors/distributed/ray.py
@@ -5,6 +5,7 @@
 
 from __future__ import annotations
 
+import asyncio
 import warnings
 from typing import Callable, Iterator, OrderedDict
 
@@ -21,6 +22,7 @@
     SyncDataCollector,
 )
 from torchrl.collectors.utils import _NON_NN_POLICY_WEIGHTS, split_trajectories
+from torchrl.data import ReplayBuffer
 from torchrl.envs.common import EnvBase
 from torchrl.envs.env_creator import EnvCreator
 
@@ -256,6 +258,11 @@ class RayCollector(DataCollectorBase):
             parameters being updated for a certain time even if ``update_after_each_batch``
             is turned on.
             Defaults to -1 (no forced update).
+        replay_buffer (RayReplayBuffer, optional): if provided, the collector will not yield tensordicts
+            but populate the buffer instead. Defaults to ``None``.
+
+            .. note:: although it is not enforced (to allow users to implement their own replay buffer class), a
+              :class:`~torchrl.data.RayReplayBuffer` instance should be used here.
 
     Examples:
         >>> from torch import nn
@@ -312,7 +319,9 @@ def __init__(
         num_collectors: int = None,
         update_after_each_batch=False,
         max_weight_update_interval=-1,
+        replay_buffer: ReplayBuffer = None,
     ):
+        self.frames_per_batch = frames_per_batch
         if remote_configs is None:
             remote_configs = DEFAULT_REMOTE_CLASS_CONFIG
 
@@ -321,6 +330,14 @@
         if collector_kwargs is None:
             collector_kwargs = {}
 
+        if replay_buffer is not None:
+            if isinstance(collector_kwargs, dict):
+                collector_kwargs.setdefault("replay_buffer", replay_buffer)
+            else:
+                for ck in collector_kwargs:
+                    ck.setdefault("replay_buffer", replay_buffer)
 
         # Make sure input parameters are consistent
         def check_consistency_with_num_collectors(param, param_name, num_collectors):
@@ -386,7 +403,8 @@ def check_list_length_consistency(*lists):
             raise RuntimeError(
                 "ray library not found, unable to create a DistributedCollector. "
" ) from RAY_ERR - ray.init(**ray_init_config) + if not ray.is_initialized(): + ray.init(**ray_init_config) if not ray.is_initialized(): raise RuntimeError("Ray could not be initialized.") @@ -400,6 +418,7 @@ def check_list_length_consistency(*lists): collector_class.as_remote = as_remote collector_class.print_remote_collector_info = print_remote_collector_info + self.replay_buffer = replay_buffer self._local_policy = policy if isinstance(self._local_policy, nn.Module): policy_weights = TensorDict.from_module(self._local_policy) @@ -557,7 +576,7 @@ def add_collectors( policy, other_params, ) - self._remote_collectors.extend([collector]) + self._remote_collectors.append(collector) def local_policy(self): """Returns local collector.""" @@ -577,17 +596,33 @@ def stop_remote_collectors(self): ) # This will interrupt any running tasks on the actor, causing them to fail immediately def iterator(self): + def proc(data): + if self.split_trajs: + data = split_trajectories(data) + if self.postproc is not None: + data = self.postproc(data) + return data + if self._sync: - data = self._sync_iterator() + meth = self._sync_iterator else: - data = self._async_iterator() + meth = self._async_iterator + yield from (proc(data) for data in meth()) - if self.split_trajs: - data = split_trajectories(data) - if self.postproc is not None: - data = self.postproc(data) + async def _asyncio_iterator(self): + def proc(data): + if self.split_trajs: + data = split_trajectories(data) + if self.postproc is not None: + data = self.postproc(data) + return data - return data + if self._sync: + for d in self._sync_iterator(): + yield proc(d) + else: + for d in self._async_iterator(): + yield proc(d) def _sync_iterator(self) -> Iterator[TensorDictBase]: """Collects one data batch per remote collector in each iteration.""" @@ -634,7 +669,30 @@ def _sync_iterator(self) -> Iterator[TensorDictBase]: ): self.update_policy_weights_(rank) - self.shutdown() + if self._task is None: + self.shutdown() + + _task = None + + def start(self): + """Starts the RayCollector.""" + if self.replay_buffer is None: + raise RuntimeError("Replay buffer must be defined for asyncio execution.") + if self._task is None or self._task.done(): + loop = asyncio.get_event_loop() + self._task = loop.create_task(self._run_iterator_silently()) + + async def _run_iterator_silently(self): + async for _ in self._asyncio_iterator(): + # Process each item silently + continue + + async def async_shutdown(self): + """Finishes processes started by ray.init() during async execution.""" + if self._task is not None: + await self._task + self.stop_remote_collectors() + ray.shutdown() def _async_iterator(self) -> Iterator[TensorDictBase]: """Collects a data batch from a single remote collector in each iteration.""" @@ -658,7 +716,7 @@ def _async_iterator(self) -> Iterator[TensorDictBase]: ray.internal.free( [future] ) # should not be necessary, deleted automatically when ref count is down to 0 - self.collected_frames += out_td.numel() + self.collected_frames += self.frames_per_batch yield out_td @@ -689,8 +747,8 @@ def _async_iterator(self) -> Iterator[TensorDictBase]: # object_ref=ref, # force=False, # ) - - self.shutdown() + if self._task is None: + self.shutdown() def update_policy_weights_(self, worker_rank=None) -> None: """Updates the weights of the worker nodes. 
diff --git a/torchrl/data/__init__.py b/torchrl/data/__init__.py
index 8ae832b257e..6b3b482560d 100644
--- a/torchrl/data/__init__.py
+++ b/torchrl/data/__init__.py
@@ -49,6 +49,7 @@
     PrioritizedSampler,
     PrioritizedSliceSampler,
     RandomSampler,
+    RayReplayBuffer,
     RemoteTensorDictReplayBuffer,
     ReplayBuffer,
     ReplayBufferEnsemble,
@@ -159,6 +160,7 @@
     "QueryModule",
     "RandomProjectionHash",
     "RandomSampler",
+    "RayReplayBuffer",
    "RemoteTensorDictReplayBuffer",
     "ReplayBuffer",
     "ReplayBufferEnsemble",
diff --git a/torchrl/data/replay_buffers/__init__.py b/torchrl/data/replay_buffers/__init__.py
index d3e8f18cb00..6e7bff8eac0 100644
--- a/torchrl/data/replay_buffers/__init__.py
+++ b/torchrl/data/replay_buffers/__init__.py
@@ -12,6 +12,7 @@
     StorageEnsembleCheckpointer,
     TensorStorageCheckpointer,
 )
+from .ray_buffer import RayReplayBuffer
 from .replay_buffers import (
     PrioritizedReplayBuffer,
     RemoteTensorDictReplayBuffer,
@@ -57,6 +58,7 @@
     "StorageCheckpointerBase",
     "StorageEnsembleCheckpointer",
     "TensorStorageCheckpointer",
+    "RayReplayBuffer",
     "PrioritizedReplayBuffer",
     "RemoteTensorDictReplayBuffer",
     "ReplayBuffer",
diff --git a/torchrl/data/replay_buffers/ray_buffer.py b/torchrl/data/replay_buffers/ray_buffer.py
new file mode 100644
index 00000000000..1ceddf60d18
--- /dev/null
+++ b/torchrl/data/replay_buffers/ray_buffer.py
@@ -0,0 +1,209 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+from __future__ import annotations
+
+from typing import Any, Callable
+
+import torch
+
+from torchrl.data.replay_buffers.replay_buffers import ReplayBuffer
+from torchrl.envs.transforms.transforms import Transform
+
+RAY_ERR = None
+try:
+    import ray
+
+    _has_ray = True
+except ImportError as err:
+    _has_ray = False
+    RAY_ERR = err
+
+
+@classmethod
+def as_remote(cls, remote_config=None):
+    """Creates an instance of a remote ray class.
+
+    Args:
+        cls (Python Class): class to be remotely instantiated.
+        remote_config (dict): keyword arguments to pass to ``ray.remote()``, e.g., the
+            number of CPU cores and GPUs to reserve for this class.
+            Defaults to `torchrl.collectors.distributed.ray.DEFAULT_REMOTE_CLASS_CONFIG`.
+
+    Returns:
+        The remote version of the class, whose instances are created via ``.remote()``.
+    """
+    if remote_config is None:
+        from torchrl.collectors.distributed.ray import DEFAULT_REMOTE_CLASS_CONFIG
+
+        remote_config = DEFAULT_REMOTE_CLASS_CONFIG
+    remote_collector = ray.remote(**remote_config)(cls)
+    remote_collector.is_remote = True
+    return remote_collector
+
+
+ReplayBuffer.as_remote = as_remote
+
+
+class RayReplayBuffer(ReplayBuffer):
+    """A Ray implementation of the Replay Buffer that can be extended and sampled remotely.
+
+    Keyword Args:
+        ray_init_config (dict[str, Any], optional): keyword arguments to pass to `ray.init()`.
+        remote_config (dict[str, Any], optional): keyword arguments to pass to `cls.as_remote()`.
+            Defaults to `torchrl.collectors.distributed.ray.DEFAULT_REMOTE_CLASS_CONFIG`.
+
+    .. seealso:: :class:`~torchrl.data.ReplayBuffer` for a list of other keyword arguments.
+
+    The writer, sampler and storage should be passed as constructors to prevent serialization issues.
+    Transform constructors should be passed through the `transform_factory` argument.
+
+    Example:
+        >>> import asyncio
+        >>> from tensordict.nn import TensorDictModule
+        >>> from torch import nn
+        >>> from torchrl.collectors.distributed.ray import RayCollector
+        >>> from torchrl.data.replay_buffers.ray_buffer import RayReplayBuffer
+        >>> from torchrl.envs.libs.gym import GymEnv
+        >>>
+        >>> async def main():
+        ...     # 1. Create environment factory
+        ...     def env_maker():
+        ...         return GymEnv("Pendulum-v1", device="cpu")
+        ...
+        ...     policy = TensorDictModule(
+        ...         nn.Linear(3, 1), in_keys=["observation"], out_keys=["action"]
+        ...     )
+        ...
+        ...     buffer = RayReplayBuffer()
+        ...
+        ...     # 2. Define distributed collector
+        ...     remote_config = {
+        ...         "num_cpus": 1,
+        ...         "num_gpus": 0,
+        ...         "memory": 5 * 1024**3,
+        ...         "object_store_memory": 2 * 1024**3,
+        ...     }
+        ...     distributed_collector = RayCollector(
+        ...         [env_maker],
+        ...         policy,
+        ...         total_frames=600,
+        ...         frames_per_batch=200,
+        ...         remote_configs=remote_config,
+        ...         replay_buffer=buffer,
+        ...     )
+        ...
+        ...     print("start")
+        ...     distributed_collector.start()
+        ...
+        ...     while True:
+        ...         while not len(buffer):
+        ...             print("waiting")
+        ...             await asyncio.sleep(1)  # Use asyncio.sleep instead of time.sleep
+        ...         print("sample", buffer.sample(32))
+        ...         # break at some point
+        ...         break
+        ...
+        ...     await distributed_collector.async_shutdown()
+        >>>
+        >>> if __name__ == "__main__":
+        ...     asyncio.run(main())
+
+    """
+
+    def __init__(
+        self,
+        *args,
+        ray_init_config: dict[str, Any] | None = None,
+        remote_config: dict[str, Any] | None = None,
+        **kwargs,
+    ) -> None:
+        if not _has_ray:
+            raise RuntimeError(
+                "ray library not found, unable to create a RayReplayBuffer. "
+            ) from RAY_ERR
+        if not ray.is_initialized():
+            if ray_init_config is None:
+                from torchrl.collectors.distributed.ray import DEFAULT_RAY_INIT_CONFIG
+
+                ray_init_config = DEFAULT_RAY_INIT_CONFIG
+            ray.init(**ray_init_config)
+
+        remote_cls = ReplayBuffer.as_remote(remote_config).remote
+        self._rb = remote_cls(*args, **kwargs)
+
+    def sample(self, *args, **kwargs):
+        pending_task = self._rb.sample.remote(*args, **kwargs)
+        return ray.get(pending_task)
+
+    def extend(self, *args, **kwargs):
+        pending_task = self._rb.extend.remote(*args, **kwargs)
+        return ray.get(pending_task)
+
+    def add(self, *args, **kwargs):
+        return ray.get(self._rb.add.remote(*args, **kwargs))
+
+    def update_priority(self, *args, **kwargs):
+        return ray.get(self._rb.update_priority.remote(*args, **kwargs))
+
+    def append_transform(self, *args, **kwargs):
+        return ray.get(self._rb.append_transform.remote(*args, **kwargs))
+
+    def dumps(self, path):
+        return ray.get(self._rb.dumps.remote(path))
+
+    def dump(self, path):
+        return ray.get(self._rb.dump.remote(path))
+
+    def loads(self, path):
+        return ray.get(self._rb.loads.remote(path))
+
+    def load(self, *args, **kwargs):
+        return ray.get(self._rb.load.remote(*args, **kwargs))
+
+    def empty(self):
+        return ray.get(self._rb.empty.remote())
+
+    def insert_transform(
+        self,
+        index: int,
+        transform: Transform,  # noqa-F821
+        *,
+        invert: bool = False,
+    ) -> ReplayBuffer:
+        return ray.get(
+            self._rb.insert_transform.remote(index, transform, invert=invert)
+        )
+
+    def mark_update(self, index: int | torch.Tensor) -> None:
+        return ray.get(self._rb.mark_update.remote(index))
+
+    def register_load_hook(self, hook: Callable[[Any], Any]):
+        return ray.get(self._rb.register_load_hook.remote(hook))
+
+    def register_save_hook(self, hook: Callable[[Any], Any]):
+        return ray.get(self._rb.register_save_hook.remote(hook))
+
+    def save(self, path: str):
+        return ray.get(self._rb.save.remote(path))
+
+    def set_rng(self, generator):
+        return ray.get(self._rb.set_rng.remote(generator))
+
+    def set_sampler(self, sampler):
+        return ray.get(self._rb.set_sampler.remote(sampler))
+
+    def set_storage(self, storage):
+        return ray.get(self._rb.set_storage.remote(storage))
+
+    def set_writer(self, writer):
+        return ray.get(self._rb.set_writer.remote(writer))
+
+    def share(self, shared: bool = True):
+        return ray.get(self._rb.share.remote(shared))
+
+    def state_dict(self):
+        return ray.get(self._rb.state_dict.remote())
+
+    def __len__(self):
+        return ray.get(self._rb.__len__.remote())
diff --git a/torchrl/data/replay_buffers/replay_buffers.py b/torchrl/data/replay_buffers/replay_buffers.py
index 20e029fc535..98b90decede 100644
--- a/torchrl/data/replay_buffers/replay_buffers.py
+++ b/torchrl/data/replay_buffers/replay_buffers.py
@@ -23,6 +23,8 @@
 except ImportError:
     from torch._dynamo import is_compiling
 
+from functools import partial
+
 from tensordict import (
     is_tensor_collection,
     is_tensorclass,
@@ -66,21 +68,24 @@
     WriterEnsemble,
 )
 from torchrl.data.utils import DEVICE_TYPING
-from torchrl.envs.transforms.transforms import _InvertTransform
+from torchrl.envs.transforms.transforms import _InvertTransform, Transform
 
 
 class ReplayBuffer:
     """A generic, composable replay buffer class.
 
     Keyword Args:
-        storage (Storage, optional): the storage to be used. If none is provided
-            a default :class:`~torchrl.data.replay_buffers.ListStorage` with
+        storage (Storage, Callable[[], Storage], optional): the storage to be used.
+            If a callable is passed, it is used as constructor for the storage.
+            If none is provided, a default :class:`~torchrl.data.replay_buffers.ListStorage` with
             ``max_size`` of ``1_000`` will be created.
-        sampler (Sampler, optional): the sampler to be used. If none is provided,
-            a default :class:`~torchrl.data.replay_buffers.RandomSampler`
+        sampler (Sampler, Callable[[], Sampler], optional): the sampler to be used.
+            If a callable is passed, it is used as constructor for the sampler.
+            If none is provided, a default :class:`~torchrl.data.replay_buffers.RandomSampler`
             will be used.
-        writer (Writer, optional): the writer to be used. If none is provided
-            a default :class:`~torchrl.data.replay_buffers.RoundRobinWriter`
+        writer (Writer, Callable[[], Writer], optional): the writer to be used.
+            If a callable is passed, it is used as constructor for the writer.
+            If none is provided, a default :class:`~torchrl.data.replay_buffers.RoundRobinWriter`
             will be used.
         collate_fn (callable, optional): merges a list of samples to form a
             mini-batch of Tensor(s)/outputs.  Used when using batched
@@ -90,12 +95,17 @@ class ReplayBuffer:
             samples.
         prefetch (int, optional): number of next batches to be prefetched
             using multithreading. Defaults to None (no prefetching).
-        transform (Transform, optional): Transform to be executed when
+        transform (Transform or Callable[[Any], Any], optional): Transform to be executed when
             :meth:`sample` is called.
             To chain transforms use the :class:`~torchrl.envs.Compose` class.
             Transforms should be used with :class:`tensordict.TensorDict`
             content. A generic callable can also be passed if the replay buffer
             is used with PyTree structures (see example below).
+            Unlike storages, writers and samplers, transform constructors must
+            be passed as the separate keyword argument :attr:`transform_factory`,
+            as it is impossible to distinguish a constructor from a transform.
+        transform_factory (Callable[[], Callable], optional): a factory for the
+            transform. Exclusive with :attr:`transform`.
         batch_size (int, optional): the batch size to be used when sample() is
             called.
@@ -215,32 +225,28 @@ class ReplayBuffer:
     def __init__(
         self,
         *,
-        storage: Storage | None = None,
-        sampler: Sampler | None = None,
-        writer: Writer | None = None,
+        storage: Storage | Callable[[], Storage] | None = None,
+        sampler: Sampler | Callable[[], Sampler] | None = None,
+        writer: Writer | Callable[[], Writer] | None = None,
         collate_fn: Callable | None = None,
         pin_memory: bool = False,
         prefetch: int | None = None,
-        transform: Transform | None = None,  # noqa-F821
+        transform: Transform | Callable | None = None,  # noqa-F821
+        transform_factory: Callable[[], Transform | Callable]
+        | None = None,  # noqa-F821
         batch_size: int | None = None,
         dim_extend: int | None = None,
-        checkpointer: StorageCheckpointerBase | None = None,  # noqa: F821
+        checkpointer: StorageCheckpointerBase  # noqa: F821
+        | Callable[[], StorageCheckpointerBase]  # noqa: F821
+        | None = None,  # noqa: F821
         generator: torch.Generator | None = None,
         shared: bool = False,
         compilable: bool = None,
     ) -> None:
-        self._storage = (
-            storage
-            if storage is not None
-            else ListStorage(max_size=1_000, compilable=compilable)
-        )
+        self._storage = self._maybe_make_storage(storage, compilable=compilable)
         self._storage.attach(self)
-        self._sampler = sampler if sampler is not None else RandomSampler()
-        self._writer = (
-            writer
-            if writer is not None
-            else RoundRobinWriter(compilable=bool(compilable))
-        )
+        self._sampler = self._maybe_make_sampler(sampler)
+        self._writer = self._maybe_make_writer(writer, compilable=compilable)
         self._writer.register_storage(self._storage)
 
         self._get_collate_fn(collate_fn)
@@ -257,24 +263,8 @@ def __init__(
         self._replay_lock = threading.RLock()
         self._futures_lock = threading.RLock()
 
-        from torchrl.envs.transforms.transforms import (
-            _CallableTransform,
-            Compose,
-            Transform,
-        )
-
-        if transform is None:
-            transform = Compose()
-        elif not isinstance(transform, Compose):
-            if not isinstance(transform, Transform) and callable(transform):
-                transform = _CallableTransform(transform)
-            elif not isinstance(transform, Transform):
-                raise RuntimeError(
-                    "transform must be either a Transform instance or a callable."
-                )
-            transform = Compose(transform)
-        transform.eval()
-        self._transform = transform
+        self._transform = self._maybe_make_transform(transform, transform_factory)
 
         if batch_size is None and prefetch:
             raise ValueError(
@@ -299,6 +289,81 @@ def __init__(
             self._storage.checkpointer = checkpointer
         self.set_rng(generator=generator)
 
+    def _maybe_make_storage(
+        self, storage: Storage | Callable[[], Storage] | None, compilable
+    ) -> Storage:
+        if storage is None:
+            return ListStorage(max_size=1_000, compilable=compilable)
+        elif isinstance(storage, Storage):
+            return storage
+        elif callable(storage):
+            storage = storage()
+            if not isinstance(storage, Storage):
+                raise TypeError(
+                    "storage must be either a Storage or a callable returning a storage instance."
+                )
+            return storage
+
+    def _maybe_make_sampler(
+        self, sampler: Sampler | Callable[[], Sampler] | None
+    ) -> Sampler:
+        if sampler is None:
+            return RandomSampler()
+        elif isinstance(sampler, Sampler):
+            return sampler
+        elif callable(sampler):
+            sampler = sampler()
+            if not isinstance(sampler, Sampler):
+                raise TypeError(
+                    "sampler must be either a Sampler or a callable returning a sampler instance."
+                )
+            return sampler
+
+    def _maybe_make_writer(
+        self,
+        writer: Writer | Callable[[], Writer] | None,
+        compilable: bool | None = None,
+    ) -> Writer:
+        if writer is None:
+            return RoundRobinWriter(compilable=bool(compilable))
+        elif isinstance(writer, Writer):
+            return writer
+        elif callable(writer):
+            writer = writer()
+            if not isinstance(writer, Writer):
+                raise TypeError(
+                    "writer must be either a Writer or a callable returning a writer instance."
+                )
+            return writer
+
+    def _maybe_make_transform(
+        self,
+        transform: Transform | Callable[[], Transform] | None,
+        transform_factory: Callable | None,
+    ) -> Transform:
+        from torchrl.envs.transforms.transforms import (
+            _CallableTransform,
+            Compose,
+            Transform,
+        )
+
+        if transform_factory is not None:
+            if transform is not None:
+                raise TypeError(
+                    "transform and transform_factory cannot be used simultaneously"
+                )
+            transform = transform_factory()
+        if transform is None:
+            transform = Compose()
+        elif not isinstance(transform, Compose):
+            if not isinstance(transform, Transform) and callable(transform):
+                transform = _CallableTransform(transform)
+            elif not isinstance(transform, Transform):
+                raise RuntimeError(
+                    "transform must be either a Transform instance or a callable."
+                )
+            transform = Compose(transform)
+        transform.eval()
+        return transform
+
     def share(self, shared: bool = True):
         self.shared = shared
         if self.shared:
@@ -390,18 +455,25 @@ def write_count(self):
     def __repr__(self) -> str:
         from torchrl.envs.transforms import Compose
 
-        storage = textwrap.indent(f"storage={self._storage}", " " * 4)
-        writer = textwrap.indent(f"writer={self._writer}", " " * 4)
-        sampler = textwrap.indent(f"sampler={self._sampler}", " " * 4)
-        if self._transform is not None and not (
-            isinstance(self._transform, Compose) and not len(self._transform)
+        storage = textwrap.indent(f"storage={getattr(self, '_storage', None)}", " " * 4)
+        writer = textwrap.indent(f"writer={getattr(self, '_writer', None)}", " " * 4)
+        sampler = textwrap.indent(f"sampler={getattr(self, '_sampler', None)}", " " * 4)
+        if getattr(self, "_transform", None) is not None and not (
+            isinstance(self._transform, Compose)
+            and not len(getattr(self, "_transform", None))
         ):
-            transform = textwrap.indent(f"transform={self._transform}", " " * 4)
+            transform = textwrap.indent(
+                f"transform={getattr(self, '_transform', None)}", " " * 4
+            )
             transform = f"\n{self._transform}, "
         else:
             transform = ""
-        batch_size = textwrap.indent(f"batch_size={self._batch_size}", " " * 4)
-        collate_fn = textwrap.indent(f"collate_fn={self._collate_fn}", " " * 4)
+        batch_size = textwrap.indent(
+            f"batch_size={getattr(self, '_batch_size', None)}", " " * 4
+        )
+        collate_fn = textwrap.indent(
+            f"collate_fn={getattr(self, '_collate_fn', None)}", " " * 4
+        )
         return f"{self.__class__.__name__}(\n{storage}, \n{sampler}, \n{writer}, {transform}\n{batch_size}, \n{collate_fn})"
 
     @pin_memory_output
@@ -833,7 +905,7 @@ def __iter__(self):
 
     def __getstate__(self) -> dict[str, Any]:
         state = self.__dict__.copy()
-        if self._rng is not None:
+        if getattr(self, "_rng", None) is not None:
             rng_state = TensorDict(
                 rng_state=self._rng.get_state().clone(),
                 device=self._rng.device,
@@ -1030,13 +1102,17 @@ class TensorDictReplayBuffer(ReplayBuffer):
     """TensorDict-specific wrapper around the :class:`~torchrl.data.ReplayBuffer` class.
 
     Keyword Args:
-        storage (Storage, optional): the storage to be used. If none is provided
-            a default :class:`~torchrl.data.replay_buffers.ListStorage` with
+        storage (Storage, Callable[[], Storage], optional): the storage to be used.
+            If a callable is passed, it is used as constructor for the storage.
+            If none is provided, a default :class:`~torchrl.data.replay_buffers.ListStorage` with
             ``max_size`` of ``1_000`` will be created.
-        sampler (Sampler, optional): the sampler to be used. If none is provided
-            a default RandomSampler() will be used.
-        writer (Writer, optional): the writer to be used. If none is provided
-            a default :class:`~torchrl.data.replay_buffers.RoundRobinWriter`
+        sampler (Sampler, Callable[[], Sampler], optional): the sampler to be used.
+            If a callable is passed, it is used as constructor for the sampler.
+            If none is provided, a default :class:`~torchrl.data.replay_buffers.RandomSampler`
+            will be used.
+        writer (Writer, Callable[[], Writer], optional): the writer to be used.
+            If a callable is passed, it is used as constructor for the writer.
+            If none is provided, a default :class:`~torchrl.data.replay_buffers.TensorDictRoundRobinWriter`
             will be used.
         collate_fn (callable, optional): merges a list of samples to form a
             mini-batch of Tensor(s)/outputs.  Used when using batched
@@ -1046,13 +1122,17 @@ class TensorDictReplayBuffer(ReplayBuffer):
             samples.
         prefetch (int, optional): number of next batches to be prefetched
             using multithreading. Defaults to None (no prefetching).
-        transform (Transform, optional): Transform to be executed when
-            sample() is called.
+        transform (Transform or Callable[[Any], Any], optional): Transform to be executed when
+            :meth:`sample` is called.
             To chain transforms use the :class:`~torchrl.envs.Compose` class.
             Transforms should be used with :class:`tensordict.TensorDict`
-            content. If used with other structures, the transforms should be
-            encoded with a ``"data"`` leading key that will be used to
-            construct a tensordict from the non-tensordict content.
+            content. A generic callable can also be passed if the replay buffer
+            is used with PyTree structures (see example below).
+            Unlike storages, writers and samplers, transform constructors must
+            be passed as the separate keyword argument :attr:`transform_factory`,
+            as it is impossible to distinguish a constructor from a transform.
+        transform_factory (Callable[[], Callable], optional): a factory for the
+            transform. Exclusive with :attr:`transform`.
         batch_size (int, optional): the batch size to be used when sample() is
             called.
@@ -1169,10 +1249,9 @@ class TensorDictReplayBuffer(ReplayBuffer):
     def __init__(self, *, priority_key: str = "td_error", **kwargs) -> None:
         writer = kwargs.get("writer", None)
         if writer is None:
-            kwargs["writer"] = TensorDictRoundRobinWriter(
-                compilable=kwargs.get("compilable")
+            kwargs["writer"] = partial(
+                TensorDictRoundRobinWriter, compilable=kwargs.get("compilable")
             )
-
         super().__init__(**kwargs)
         self.priority_key = priority_key
@@ -1381,8 +1460,9 @@ class TensorDictPrioritizedReplayBuffer(TensorDictReplayBuffer):
         beta (:obj:`float`): importance sampling negative exponent.
         eps (:obj:`float`): delta added to the priorities to ensure that the buffer
             does not contain null priorities.
-        storage (Storage, optional): the storage to be used. If none is provided
-            a default :class:`~torchrl.data.replay_buffers.ListStorage` with
+        storage (Storage, Callable[[], Storage], optional): the storage to be used.
+            If a callable is passed, it is used as constructor for the storage.
+            If none is provided, a default :class:`~torchrl.data.replay_buffers.ListStorage` with
             ``max_size`` of ``1_000`` will be created.
         collate_fn (callable, optional): merges a list of samples to form a
             mini-batch of Tensor(s)/outputs.  Used when using batched
@@ -1392,13 +1472,17 @@ class TensorDictPrioritizedReplayBuffer(TensorDictReplayBuffer):
             samples.
         prefetch (int, optional): number of next batches to be prefetched
             using multithreading. Defaults to None (no prefetching).
-        transform (Transform, optional): Transform to be executed when
-            sample() is called.
+        transform (Transform or Callable[[Any], Any], optional): Transform to be executed when
+            :meth:`sample` is called.
             To chain transforms use the :class:`~torchrl.envs.Compose` class.
             Transforms should be used with :class:`tensordict.TensorDict`
-            content. If used with other structures, the transforms should be
-            encoded with a ``"data"`` leading key that will be used to
-            construct a tensordict from the non-tensordict content.
+            content. A generic callable can also be passed if the replay buffer
+            is used with PyTree structures (see example below).
+            Unlike storages, writers and samplers, transform constructors must
+            be passed as the separate keyword argument :attr:`transform_factory`,
+            as it is impossible to distinguish a constructor from a transform.
+        transform_factory (Callable[[], Callable], optional): a factory for the
+            transform. Exclusive with :attr:`transform`.
         batch_size (int, optional): the batch size to be used when sample() is
             called.
@@ -1530,8 +1614,7 @@ def __init__(
         shared: bool = False,
         compilable: bool = False,
     ) -> None:
-        if storage is None:
-            storage = ListStorage(max_size=1_000)
+        storage = self._maybe_make_storage(storage, compilable=compilable)
         sampler = PrioritizedSampler(
             storage.max_size, alpha, beta, eps, reduction=reduction
         )
diff --git a/torchrl/data/replay_buffers/storages.py b/torchrl/data/replay_buffers/storages.py
index c9f6715984d..54e01b00718 100644
--- a/torchrl/data/replay_buffers/storages.py
+++ b/torchrl/data/replay_buffers/storages.py
@@ -344,7 +344,10 @@ def __getstate__(self):
         return state
 
     def __repr__(self):
-        return f"{self.__class__.__name__}(items=[{self._storage[0]}, ...])"
+        storage = getattr(self, "_storage", [None])
+        if not storage:
+            return f"{self.__class__.__name__}()"
+        return f"{self.__class__.__name__}(items=[{storage[0]}, ...])"
 
     def contains(self, item):
         if isinstance(item, int):
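[Editor's note] The constructor-style arguments exercised by `test_ray_replaybuffer` above also work on a plain, local `ReplayBuffer`. A minimal sketch, using only components imported in this diff; the identity transform passed through `transform_factory` is a hypothetical placeholder:

# Sketch: building a ReplayBuffer from constructors rather than instances.
from functools import partial

import torch
from tensordict import TensorDict
from torchrl.data import (
    LazyTensorStorage,
    RandomSampler,
    ReplayBuffer,
    RoundRobinWriter,
)

rb = ReplayBuffer(
    storage=partial(LazyTensorStorage, 1000),  # constructor, instantiated in __init__
    sampler=RandomSampler,  # a class is itself a zero-argument constructor
    writer=partial(RoundRobinWriter),
    transform_factory=lambda: (lambda td: td),  # hypothetical identity transform
    batch_size=32,
)
rb.extend(TensorDict(a=torch.arange(100), batch_size=[100]))
print(rb.sample())  # a TensorDict with batch size [32]

Passing constructors rather than instances is what lets `RayReplayBuffer` forward these arguments to the remote actor without serializing stateful storage objects.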