diff --git a/test/test_distributed.py b/test/test_distributed.py index 8e939cce42e..c49596bb26f 100644 --- a/test/test_distributed.py +++ b/test/test_distributed.py @@ -556,19 +556,21 @@ def test_distributed_collector_updatepolicy(self, collector_class, sync): total = 0 first_batch = None last_batch = None - for i, data in enumerate(collector): - total += data.numel() - assert data.numel() == frames_per_batch - if i == 0: - first_batch = data - policy.weight.data += 1 - collector.update_policy_weights_() - elif total == total_frames - frames_per_batch: - last_batch = data - assert (first_batch["action"] == 1).all(), first_batch["action"] - assert (last_batch["action"] == 2).all(), last_batch["action"] - collector.shutdown() - assert total == total_frames + try: + for i, data in enumerate(collector): + total += data.numel() + assert data.numel() == frames_per_batch + if i == 0: + first_batch = data + policy.weight.data += 1 + collector.update_policy_weights_() + elif total == total_frames - frames_per_batch: + last_batch = data + assert (first_batch["action"] == 1).all(), first_batch["action"] + assert (last_batch["action"] == 2).all(), last_batch["action"] + assert total == total_frames + finally: + collector.shutdown() @pytest.mark.parametrize("storage", [None, partial(LazyTensorStorage, 1000)]) @pytest.mark.parametrize( @@ -593,6 +595,34 @@ def test_ray_replaybuffer(self, storage, sampler, writer): if sampler is SamplerWithoutReplacement: assert sample["a"].unique().numel() == sample.numel() + # class CustomCollectorCls(SyncDataCollector): + # def __init__(self, create_env_fn, **kwargs): + # policy = lambda td: td.set("action", torch.full(td.shape, 2)) + # super().__init__(create_env_fn, policy, **kwargs) + + def test_ray_collector_policy_constructor(self): + n_collectors = 2 + frames_per_batch = 50 + total_frames = 300 + env = CountingEnv + + def policy_constructor(): + return lambda td: td.set("action", torch.full(td.shape, 2)) + + collector = self.distributed_class()( + [env] * n_collectors, + collector_class=SyncDataCollector, + policy_factory=policy_constructor, + total_frames=total_frames, + frames_per_batch=frames_per_batch, + **self.distributed_kwargs(), + ) + try: + for data in collector: + assert (data["action"] == 2).all() + finally: + collector.shutdown() + if __name__ == "__main__": args, unknown = argparse.ArgumentParser().parse_known_args() diff --git a/torchrl/collectors/collectors.py b/torchrl/collectors/collectors.py index 37e47b2002e..58f67229413 100644 --- a/torchrl/collectors/collectors.py +++ b/torchrl/collectors/collectors.py @@ -21,7 +21,7 @@ from multiprocessing.managers import SyncManager from queue import Empty from textwrap import indent -from typing import Any, Callable, Iterator, Sequence +from typing import Any, Callable, Iterator, Sequence, TypeVar import numpy as np import torch @@ -86,6 +86,8 @@ def cudagraph_mark_step_begin(): _is_osx = sys.platform.startswith("darwin") +T = TypeVar("T") + class _Interruptor: """A class for managing the collection state of a process. @@ -343,7 +345,15 @@ class SyncDataCollector(DataCollectorBase): - In all other cases an attempt to wrap it will be undergone as such: ``TensorDictModule(policy, in_keys=env_obs_key, out_keys=env.action_keys)``. + .. note:: If the policy needs to be passed as a policy factory (e.g., in case it mustn't be serialized / + pickled directly), the :arg:`policy_factory` should be used instead. 
+ Keyword Args: + policy_factory (Callable[[], Callable], optional): a callable that returns + a policy instance. This is exclusive with the `policy` argument. + + .. note:: `policy_factory` comes in handy whenever the policy cannot be serialized. + frames_per_batch (int): A keyword-only argument representing the total number of elements in a batch. total_frames (int): A keyword-only argument representing the total @@ -515,6 +525,7 @@ def __init__( policy: None | (TensorDictModule | Callable[[TensorDictBase], TensorDictBase]) = None, *, + policy_factory: Callable[[], Callable] | None = None, frames_per_batch: int, total_frames: int = -1, device: DEVICE_TYPING = None, @@ -558,8 +569,13 @@ def __init__( env.update_kwargs(create_env_kwargs) if policy is None: + if policy_factory is not None: + policy = policy_factory() + else: + policy = RandomPolicy(env.full_action_spec) + elif policy_factory is not None: + raise TypeError("policy_factory cannot be used with policy argument.") - policy = RandomPolicy(env.full_action_spec) if trust_policy is None: trust_policy = isinstance(policy, (RandomPolicy, CudaGraphModule)) self.trust_policy = trust_policy @@ -1429,17 +1445,22 @@ def load_state_dict(self, state_dict: OrderedDict, **kwargs) -> None: self._iter = state_dict["iter"] def __repr__(self) -> str: - env_str = indent(f"env={self.env}", 4 * " ") - policy_str = indent(f"policy={self.policy}", 4 * " ") - td_out_str = indent(f"td_out={getattr(self, '_final_rollout', None)}", 4 * " ") - string = ( - f"{self.__class__.__name__}(" - f"\n{env_str}," - f"\n{policy_str}," - f"\n{td_out_str}," - f"\nexploration={self.exploration_type})" - ) - return string + try: + env_str = indent(f"env={self.env}", 4 * " ") + policy_str = indent(f"policy={self.policy}", 4 * " ") + td_out_str = indent( + f"td_out={getattr(self, '_final_rollout', None)}", 4 * " " + ) + string = ( + f"{self.__class__.__name__}(" + f"\n{env_str}," + f"\n{policy_str}," + f"\n{td_out_str}," + f"\nexploration={self.exploration_type})" + ) + return string + except AttributeError: + return f"{type(self).__name__}(not_init)" class _MultiDataCollector(DataCollectorBase): @@ -1469,7 +1490,18 @@ class _MultiDataCollector(DataCollectorBase): - In all other cases an attempt to wrap it will be undergone as such: ``TensorDictModule(policy, in_keys=env_obs_key, out_keys=env.action_keys)``. + .. note:: If the policy needs to be passed as a policy factory (e.g., in case it mustn't be serialized / + pickled directly), the :arg:`policy_factory` should be used instead. + Keyword Args: + policy_factory (Callable[[], Callable], optional): a callable that returns + a policy instance. This is exclusive with the `policy` argument. + + .. note:: `policy_factory` comes in handy whenever the policy cannot be serialized. + + .. warning:: `policy_factory` is currently not compatible with multiprocessed data + collectors. + frames_per_batch (int): A keyword-only argument representing the total number of elements in a batch. 
total_frames (int, optional): A keyword-only argument representing the @@ -1612,6 +1644,7 @@ def __init__( policy: None | (TensorDictModule | Callable[[TensorDictBase], TensorDictBase]) = None, *, + policy_factory: Callable[[], Callable] | None = None, frames_per_batch: int, total_frames: int | None = -1, device: DEVICE_TYPING | Sequence[DEVICE_TYPING] | None = None, @@ -1695,27 +1728,36 @@ def __init__( self._get_weights_fn_dict = {} if trust_policy is None: - trust_policy = isinstance(policy, CudaGraphModule) + trust_policy = policy is not None and isinstance(policy, CudaGraphModule) self.trust_policy = trust_policy - for policy_device, env_maker, env_maker_kwargs in zip( - self.policy_device, self.create_env_fn, self.create_env_kwargs - ): - (policy_copy, get_weights_fn,) = self._get_policy_and_device( - policy=policy, - policy_device=policy_device, - env_maker=env_maker, - env_maker_kwargs=env_maker_kwargs, - ) - if type(policy_copy) is not type(policy): - policy = policy_copy - weights = ( - TensorDict.from_module(policy_copy) - if isinstance(policy_copy, nn.Module) - else TensorDict() + if policy_factory is not None and policy is not None: + raise TypeError("policy_factory and policy are mutually exclusive") + elif policy_factory is None: + for policy_device, env_maker, env_maker_kwargs in zip( + self.policy_device, self.create_env_fn, self.create_env_kwargs + ): + (policy_copy, get_weights_fn,) = self._get_policy_and_device( + policy=policy, + policy_device=policy_device, + env_maker=env_maker, + env_maker_kwargs=env_maker_kwargs, + ) + if type(policy_copy) is not type(policy): + policy = policy_copy + weights = ( + TensorDict.from_module(policy_copy) + if isinstance(policy_copy, nn.Module) + else TensorDict() + ) + self._policy_weights_dict[policy_device] = weights + self._get_weights_fn_dict[policy_device] = get_weights_fn + else: + # TODO + raise NotImplementedError( + "weight syncing is not supported for multiprocessed data collectors at the " + "moment." ) - self._policy_weights_dict[policy_device] = weights - self._get_weights_fn_dict[policy_device] = get_weights_fn self.policy = policy remainder = 0 @@ -2782,7 +2824,15 @@ class aSyncDataCollector(MultiaSyncDataCollector): - In all other cases an attempt to wrap it will be undergone as such: ``TensorDictModule(policy, in_keys=env_obs_key, out_keys=env.action_keys)``. + .. note:: If the policy needs to be passed as a policy factory (e.g., in case it mustn't be serialized / + pickled directly), the :arg:`policy_factory` should be used instead. + Keyword Args: + policy_factory (Callable[[], Callable], optional): a callable that returns + a policy instance. This is exclusive with the `policy` argument. + + .. note:: `policy_factory` comes in handy whenever the policy cannot be serialized. + frames_per_batch (int): A keyword-only argument representing the total number of elements in a batch. 
total_frames (int, optional): A keyword-only argument representing the @@ -2888,8 +2938,10 @@ class aSyncDataCollector(MultiaSyncDataCollector): def __init__( self, create_env_fn: Callable[[], EnvBase], - policy: None | (TensorDictModule | Callable[[TensorDictBase], TensorDictBase]), + policy: None + | (TensorDictModule | Callable[[TensorDictBase], TensorDictBase]) = None, *, + policy_factory: Callable[[], Callable] | None = None, frames_per_batch: int, total_frames: int | None = -1, device: DEVICE_TYPING | Sequence[DEVICE_TYPING] | None = None, @@ -2914,6 +2966,7 @@ def __init__( super().__init__( create_env_fn=[create_env_fn], policy=policy, + policy_factory=policy_factory, total_frames=total_frames, create_env_kwargs=[create_env_kwargs], max_frames_per_traj=max_frames_per_traj, diff --git a/torchrl/collectors/distributed/generic.py b/torchrl/collectors/distributed/generic.py index 5ec55e23a16..00ea799d55f 100644 --- a/torchrl/collectors/distributed/generic.py +++ b/torchrl/collectors/distributed/generic.py @@ -14,7 +14,7 @@ from typing import Callable, OrderedDict import torch.cuda -from tensordict import TensorDict +from tensordict import TensorDict, TensorDictBase from torch import nn from torchrl._utils import _ProcessNoWarn, logger as torchrl_logger, VERBOSE @@ -270,7 +270,15 @@ class DistributedDataCollector(DataCollectorBase): - In all other cases an attempt to wrap it will be undergone as such: ``TensorDictModule(policy, in_keys=env_obs_key, out_keys=env.action_keys)``. + .. note:: If the policy needs to be passed as a policy factory (e.g., in case it mustn't be serialized / + pickled directly), the :arg:`policy_factory` should be used instead. + Keyword Args: + policy_factory (Callable[[], Callable], optional): a callable that returns + a policy instance. This is exclusive with the `policy` argument. + + .. note:: `policy_factory` comes in handy whenever the policy cannot be serialized. + frames_per_batch (int): A keyword-only argument representing the total number of elements in a batch. total_frames (int): A keyword-only argument representing the total @@ -406,8 +414,9 @@ class DistributedDataCollector(DataCollectorBase): def __init__( self, create_env_fn, - policy, + policy: Callable[[TensorDictBase], TensorDictBase] | None = None, *, + policy_factory: Callable[[], Callable] | None = None, frames_per_batch: int, total_frames: int = -1, device: torch.device | list[torch.device] = None, diff --git a/torchrl/collectors/distributed/ray.py b/torchrl/collectors/distributed/ray.py index 1d56e066f0e..279ef9e4aae 100644 --- a/torchrl/collectors/distributed/ray.py +++ b/torchrl/collectors/distributed/ray.py @@ -124,7 +124,7 @@ class RayCollector(DataCollectorBase): Args: create_env_fn (Callable or List[Callabled]): list of Callables, each returning an instance of :class:`~torchrl.envs.EnvBase`. - policy (Callable): Policy to be executed in the environment. + policy (Callable, optional): Policy to be executed in the environment. Must accept :class:`tensordict.tensordict.TensorDictBase` object as input. If ``None`` is provided, the policy used will be a :class:`~torchrl.collectors.RandomPolicy` instance with the environment @@ -144,7 +144,15 @@ class RayCollector(DataCollectorBase): - In all other cases an attempt to wrap it will be undergone as such: ``TensorDictModule(policy, in_keys=env_obs_key, out_keys=env.action_keys)``. + .. 
note:: If the policy needs to be passed as a policy factory (e.g., in case it mustn't be serialized / + pickled directly), the :arg:`policy_factory` should be used instead. + Keyword Args: + policy_factory (Callable[[], Callable], optional): a callable that returns + a policy instance. This is exclusive with the `policy` argument. + + .. note:: `policy_factory` comes in handy whenever the policy cannot be serialized. + frames_per_batch (int): A keyword-only argument representing the total number of elements in a batch. total_frames (int, Optional): lower bound of the total number of frames returned by the collector. @@ -296,8 +304,9 @@ class RayCollector(DataCollectorBase): def __init__( self, create_env_fn: Callable | EnvBase | list[Callable] | list[EnvBase], - policy: Callable[[TensorDict], TensorDict], + policy: Callable[[TensorDictBase], TensorDictBase] | None = None, *, + policy_factory: Callable[[], Callable] | None = None, frames_per_batch: int, total_frames: int = -1, device: torch.device | list[torch.device] = None, @@ -415,8 +424,16 @@ def check_list_length_consistency(*lists): collector_class = MultiSyncDataCollector elif collector_class == "single": collector_class = SyncDataCollector - collector_class.as_remote = as_remote - collector_class.print_remote_collector_info = print_remote_collector_info + elif not isinstance(collector_class, type) or not issubclass( + collector_class, DataCollectorBase + ): + raise TypeError( + "The collector_class must be an instance of DataCollectorBase." + ) + if not hasattr(collector_class, "as_remote"): + collector_class.as_remote = as_remote + if not hasattr(collector_class, "print_remote_collector_info"): + collector_class.print_remote_collector_info = print_remote_collector_info self.replay_buffer = replay_buffer self._local_policy = policy @@ -456,6 +473,7 @@ def check_list_length_consistency(*lists): # update collector kwargs for i, collector_kwarg in enumerate(self.collector_kwargs): + collector_kwarg["policy_factory"] = policy_factory collector_kwarg["max_frames_per_traj"] = max_frames_per_traj collector_kwarg["init_random_frames"] = ( init_random_frames // self.num_collectors @@ -545,11 +563,12 @@ def policy_device(self, value): self._policy_device = [value] * self.num_collectors @staticmethod - def _make_collector(cls, env_maker, policy, other_params): + def _make_collector(cls, *, env_maker, policy, other_params): """Create a single collector instance.""" + if policy is not None: + other_params["policy"] = policy collector = cls( env_maker, - policy, total_frames=-1, **other_params, ) @@ -570,11 +589,15 @@ def add_collectors( cls = self.collector_class.as_remote(remote_config).remote collector = self._make_collector( cls, - [env_maker] * num_envs - if self.collector_class is not SyncDataCollector + env_maker=[env_maker] * num_envs + if num_envs > 1 + or ( + isinstance(self.collector_class, type) + and not issubclass(self.collector_class, SyncDataCollector) + ) else env_maker, - policy, - other_params, + policy=policy, + other_params=other_params, ) self._remote_collectors.append(collector) diff --git a/torchrl/collectors/distributed/rpc.py b/torchrl/collectors/distributed/rpc.py index ee73cfdf4e7..f205325ed30 100644 --- a/torchrl/collectors/distributed/rpc.py +++ b/torchrl/collectors/distributed/rpc.py @@ -14,6 +14,7 @@ from copy import copy, deepcopy from typing import Callable, OrderedDict +from tensordict import TensorDictBase from torchrl._utils import logger as torchrl_logger from torchrl.collectors.distributed import 
DEFAULT_SLURM_CONF from torchrl.collectors.distributed.default_configs import ( @@ -117,7 +118,15 @@ class RPCDataCollector(DataCollectorBase): - In all other cases an attempt to wrap it will be undergone as such: ``TensorDictModule(policy, in_keys=env_obs_key, out_keys=env.action_keys)``. + .. note:: If the policy needs to be passed as a policy factory (e.g., in case it mustn't be serialized / + pickled directly), the :arg:`policy_factory` should be used instead. + Keyword Args: + policy_factory (Callable[[], Callable], optional): a callable that returns + a policy instance. This is exclusive with the `policy` argument. + + .. note:: `policy_factory` comes in handy whenever the policy cannot be serialized. + frames_per_batch (int): A keyword-only argument representing the total number of elements in a batch. total_frames (int): A keyword-only argument representing the total @@ -261,8 +270,9 @@ class RPCDataCollector(DataCollectorBase): def __init__( self, create_env_fn, - policy, + policy: Callable[[TensorDictBase], TensorDictBase] | None = None, *, + policy_factory: Callable[[], Callable] | None = None, frames_per_batch: int, total_frames: int = -1, device: torch.device | list[torch.device] = None, diff --git a/torchrl/collectors/distributed/sync.py b/torchrl/collectors/distributed/sync.py index 0a2215e0abe..b15045f8288 100644 --- a/torchrl/collectors/distributed/sync.py +++ b/torchrl/collectors/distributed/sync.py @@ -14,7 +14,7 @@ from typing import Callable, OrderedDict import torch.cuda -from tensordict import TensorDict +from tensordict import TensorDict, TensorDictBase from torch import nn from torchrl._utils import _ProcessNoWarn, logger as torchrl_logger, VERBOSE @@ -150,7 +150,15 @@ class DistributedSyncDataCollector(DataCollectorBase): - In all other cases an attempt to wrap it will be undergone as such: ``TensorDictModule(policy, in_keys=env_obs_key, out_keys=env.action_keys)``. + .. note:: If the policy needs to be passed as a policy factory (e.g., in case it mustn't be serialized / + pickled directly), the :arg:`policy_factory` should be used instead. + Keyword Args: + policy_factory (Callable[[], Callable], optional): a callable that returns + a policy instance. This is exclusive with the `policy` argument. + + .. note:: `policy_factory` comes in handy whenever the policy cannot be serialized. + frames_per_batch (int): A keyword-only argument representing the total number of elements in a batch. total_frames (int): A keyword-only argument representing the total @@ -271,8 +279,9 @@ class DistributedSyncDataCollector(DataCollectorBase): def __init__( self, create_env_fn, - policy, + policy: Callable[[TensorDictBase], TensorDictBase] | None = None, *, + policy_factory: Callable[[], Callable] | None = None, frames_per_batch: int, total_frames: int = -1, device: torch.device | list[torch.device] = None,
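
For orientation, below is a minimal usage sketch of the ``policy_factory`` keyword documented above, using ``SyncDataCollector`` directly. The Gym environment name and the spec-sampling lambda are illustrative assumptions only (a Gym/Gymnasium install is presumed); the patch itself only requires that the factory, called with no arguments, returns a policy callable, and that ``policy`` and ``policy_factory`` are not passed together.

    from torchrl.collectors import SyncDataCollector
    from torchrl.envs import GymEnv


    def make_env():
        # Illustrative environment; any EnvBase constructor works here.
        return GymEnv("CartPole-v1")


    def policy_factory():
        # The policy is only built when the factory is called, i.e. wherever
        # the collector lives, so the (possibly unpicklable) object itself
        # never has to travel. Here it just samples valid actions from the
        # action spec, mirroring the lambda used in
        # test_ray_collector_policy_constructor.
        spec = make_env().action_spec
        return lambda td: td.set("action", spec.rand(td.shape))


    collector = SyncDataCollector(
        make_env,
        policy_factory=policy_factory,  # mutually exclusive with ``policy``
        frames_per_batch=50,
        total_frames=200,
    )
    try:
        for batch in collector:
            print(batch["action"].shape)
    finally:
        collector.shutdown()

``RayCollector`` forwards the factory to every remote collector through ``collector_kwarg["policy_factory"]``, whereas the multiprocessed collectors built on ``_MultiDataCollector`` currently raise ``NotImplementedError`` when a factory is supplied, as flagged by the docstring warning above.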