56 changes: 43 additions & 13 deletions test/test_distributed.py
@@ -556,19 +556,21 @@ def test_distributed_collector_updatepolicy(self, collector_class, sync):
total = 0
first_batch = None
last_batch = None
for i, data in enumerate(collector):
total += data.numel()
assert data.numel() == frames_per_batch
if i == 0:
first_batch = data
policy.weight.data += 1
collector.update_policy_weights_()
elif total == total_frames - frames_per_batch:
last_batch = data
assert (first_batch["action"] == 1).all(), first_batch["action"]
assert (last_batch["action"] == 2).all(), last_batch["action"]
collector.shutdown()
assert total == total_frames
try:
for i, data in enumerate(collector):
total += data.numel()
assert data.numel() == frames_per_batch
if i == 0:
first_batch = data
policy.weight.data += 1
collector.update_policy_weights_()
elif total == total_frames - frames_per_batch:
last_batch = data
assert (first_batch["action"] == 1).all(), first_batch["action"]
assert (last_batch["action"] == 2).all(), last_batch["action"]
assert total == total_frames
finally:
collector.shutdown()

@pytest.mark.parametrize("storage", [None, partial(LazyTensorStorage, 1000)])
@pytest.mark.parametrize(
@@ -593,6 +595,34 @@ def test_ray_replaybuffer(self, storage, sampler, writer):
if sampler is SamplerWithoutReplacement:
assert sample["a"].unique().numel() == sample.numel()

# class CustomCollectorCls(SyncDataCollector):
Contributor: nit: remove comments

# def __init__(self, create_env_fn, **kwargs):
# policy = lambda td: td.set("action", torch.full(td.shape, 2))
# super().__init__(create_env_fn, policy, **kwargs)

def test_ray_collector_policy_constructor(self):
n_collectors = 2
frames_per_batch = 50
total_frames = 300
env = CountingEnv

def policy_constructor():
return lambda td: td.set("action", torch.full(td.shape, 2))

collector = self.distributed_class()(
[env] * n_collectors,
collector_class=SyncDataCollector,
policy_factory=policy_constructor,
total_frames=total_frames,
frames_per_batch=frames_per_batch,
**self.distributed_kwargs(),
)
try:
for data in collector:
assert (data["action"] == 2).all()
finally:
collector.shutdown()


if __name__ == "__main__":
args, unknown = argparse.ArgumentParser().parse_known_args()
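The new test above exercises `policy_factory` through the Ray collector; as a rough, non-authoritative sketch (not part of this diff), the same keyword on a plain single-process collector could be used as follows. The Pendulum-v1 environment, the zero-action policy and the frame counts are placeholder assumptions, and gymnasium needs to be installed:

# Hedged sketch (not from the PR): SyncDataCollector built from a policy factory.
import torch
from torchrl.collectors import SyncDataCollector
from torchrl.envs import GymEnv  # placeholder environment backend


def make_policy():
    # Stand-in for a policy that cannot be pickled; fills the "action" entry
    # with a valid zero action for Pendulum-v1.
    return lambda td: td.set("action", torch.zeros(*td.shape, 1))


collector = SyncDataCollector(
    lambda: GymEnv("Pendulum-v1"),   # env factory (placeholder)
    policy_factory=make_policy,      # mutually exclusive with `policy`
    frames_per_batch=50,
    total_frames=200,
)
try:
    for batch in collector:
        assert batch.numel() == 50   # each batch carries frames_per_batch frames
finally:
    collector.shutdown()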
117 changes: 85 additions & 32 deletions torchrl/collectors/collectors.py
@@ -21,7 +21,7 @@
from multiprocessing.managers import SyncManager
from queue import Empty
from textwrap import indent
from typing import Any, Callable, Iterator, Sequence
from typing import Any, Callable, Iterator, Sequence, TypeVar

import numpy as np
import torch
@@ -86,6 +86,8 @@ def cudagraph_mark_step_begin():

_is_osx = sys.platform.startswith("darwin")

T = TypeVar("T")


class _Interruptor:
"""A class for managing the collection state of a process.
@@ -343,7 +345,15 @@ class SyncDataCollector(DataCollectorBase):

- In all other cases an attempt to wrap it will be undergone as such: ``TensorDictModule(policy, in_keys=env_obs_key, out_keys=env.action_keys)``.

.. note:: If the policy needs to be passed as a policy factory (e.g., in case it mustn't be serialized /
pickled directly), the :arg:`policy_factory` should be used instead.

Keyword Args:
policy_factory (Callable[[], Callable], optional): a callable that returns
a policy instance. This is exclusive with the `policy` argument.

.. note:: `policy_factory` comes in handy whenever the policy cannot be serialized.

frames_per_batch (int): A keyword-only argument representing the total
number of elements in a batch.
total_frames (int): A keyword-only argument representing the total
@@ -515,6 +525,7 @@ def __init__(
policy: None
| (TensorDictModule | Callable[[TensorDictBase], TensorDictBase]) = None,
*,
policy_factory: Callable[[], Callable] | None = None,
frames_per_batch: int,
total_frames: int = -1,
device: DEVICE_TYPING = None,
@@ -558,8 +569,13 @@ def __init__(
env.update_kwargs(create_env_kwargs)

if policy is None:
if policy_factory is not None:
policy = policy_factory()
else:
policy = RandomPolicy(env.full_action_spec)
elif policy_factory is not None:
raise TypeError("policy_factory cannot be used with policy argument.")

policy = RandomPolicy(env.full_action_spec)
if trust_policy is None:
trust_policy = isinstance(policy, (RandomPolicy, CudaGraphModule))
self.trust_policy = trust_policy
@@ -1429,17 +1445,22 @@ def load_state_dict(self, state_dict: OrderedDict, **kwargs) -> None:
self._iter = state_dict["iter"]

def __repr__(self) -> str:
env_str = indent(f"env={self.env}", 4 * " ")
policy_str = indent(f"policy={self.policy}", 4 * " ")
td_out_str = indent(f"td_out={getattr(self, '_final_rollout', None)}", 4 * " ")
string = (
f"{self.__class__.__name__}("
f"\n{env_str},"
f"\n{policy_str},"
f"\n{td_out_str},"
f"\nexploration={self.exploration_type})"
)
return string
try:
env_str = indent(f"env={self.env}", 4 * " ")
policy_str = indent(f"policy={self.policy}", 4 * " ")
td_out_str = indent(
f"td_out={getattr(self, '_final_rollout', None)}", 4 * " "
)
string = (
f"{self.__class__.__name__}("
f"\n{env_str},"
f"\n{policy_str},"
f"\n{td_out_str},"
f"\nexploration={self.exploration_type})"
)
return string
except AttributeError:
return f"{type(self).__name__}(not_init)"
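A quick illustration of what the try/except above buys (my sketch, not part of the diff): repr() on a collector whose __init__ has not run, or failed early, now degrades gracefully instead of raising.

# Hedged sketch: skip __init__ on purpose just to exercise the new fallback branch.
from torchrl.collectors import SyncDataCollector

half_built = SyncDataCollector.__new__(SyncDataCollector)
print(repr(half_built))  # expected with this patch: SyncDataCollector(not_init)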


class _MultiDataCollector(DataCollectorBase):
@@ -1469,7 +1490,18 @@ class _MultiDataCollector(DataCollectorBase):
- In all other cases an attempt to wrap it will be undergone as such:
``TensorDictModule(policy, in_keys=env_obs_key, out_keys=env.action_keys)``.

.. note:: If the policy needs to be passed as a policy factory (e.g., in case it mustn't be serialized /
pickled directly), the :arg:`policy_factory` should be used instead.

Keyword Args:
policy_factory (Callable[[], Callable], optional): a callable that returns
a policy instance. This is exclusive with the `policy` argument.

.. note:: `policy_factory` comes in handy whenever the policy cannot be serialized.

.. warning:: `policy_factory` is currently not compatible with multiprocessed data
collectors.

frames_per_batch (int): A keyword-only argument representing the
total number of elements in a batch.
total_frames (int, optional): A keyword-only argument representing the
@@ -1612,6 +1644,7 @@ def __init__(
policy: None
| (TensorDictModule | Callable[[TensorDictBase], TensorDictBase]) = None,
*,
policy_factory: Callable[[], Callable] | None = None,
frames_per_batch: int,
total_frames: int | None = -1,
device: DEVICE_TYPING | Sequence[DEVICE_TYPING] | None = None,
@@ -1695,27 +1728,36 @@ def __init__(
self._get_weights_fn_dict = {}

if trust_policy is None:
trust_policy = isinstance(policy, CudaGraphModule)
trust_policy = policy is not None and isinstance(policy, CudaGraphModule)
self.trust_policy = trust_policy

for policy_device, env_maker, env_maker_kwargs in zip(
self.policy_device, self.create_env_fn, self.create_env_kwargs
):
(policy_copy, get_weights_fn,) = self._get_policy_and_device(
policy=policy,
policy_device=policy_device,
env_maker=env_maker,
env_maker_kwargs=env_maker_kwargs,
)
if type(policy_copy) is not type(policy):
policy = policy_copy
weights = (
TensorDict.from_module(policy_copy)
if isinstance(policy_copy, nn.Module)
else TensorDict()
if policy_factory is not None and policy is not None:
raise TypeError("policy_factory and policy are mutually exclusive")
elif policy_factory is None:
for policy_device, env_maker, env_maker_kwargs in zip(
self.policy_device, self.create_env_fn, self.create_env_kwargs
):
(policy_copy, get_weights_fn,) = self._get_policy_and_device(
policy=policy,
policy_device=policy_device,
env_maker=env_maker,
env_maker_kwargs=env_maker_kwargs,
)
if type(policy_copy) is not type(policy):
policy = policy_copy
weights = (
TensorDict.from_module(policy_copy)
if isinstance(policy_copy, nn.Module)
else TensorDict()
)
self._policy_weights_dict[policy_device] = weights
self._get_weights_fn_dict[policy_device] = get_weights_fn
else:
# TODO
raise NotImplementedError(
"weight syncing is not supported for multiprocessed data collectors at the "
"moment."
)
self._policy_weights_dict[policy_device] = weights
self._get_weights_fn_dict[policy_device] = get_weights_fn
self.policy = policy

remainder = 0
@@ -2782,7 +2824,15 @@ class aSyncDataCollector(MultiaSyncDataCollector):

- In all other cases an attempt to wrap it will be undergone as such: ``TensorDictModule(policy, in_keys=env_obs_key, out_keys=env.action_keys)``.

.. note:: If the policy needs to be passed as a policy factory (e.g., in case it mustn't be serialized /
pickled directly), the :arg:`policy_factory` should be used instead.

Keyword Args:
policy_factory (Callable[[], Callable], optional): a callable that returns
a policy instance. This is exclusive with the `policy` argument.

.. note:: `policy_factory` comes in handy whenever the policy cannot be serialized.

frames_per_batch (int): A keyword-only argument representing the
total number of elements in a batch.
total_frames (int, optional): A keyword-only argument representing the
@@ -2888,8 +2938,10 @@ class aSyncDataCollector(MultiaSyncDataCollector):
def __init__(
self,
create_env_fn: Callable[[], EnvBase],
policy: None | (TensorDictModule | Callable[[TensorDictBase], TensorDictBase]),
policy: None
| (TensorDictModule | Callable[[TensorDictBase], TensorDictBase]) = None,
*,
policy_factory: Callable[[], Callable] | None = None,
frames_per_batch: int,
total_frames: int | None = -1,
device: DEVICE_TYPING | Sequence[DEVICE_TYPING] | None = None,
@@ -2914,6 +2966,7 @@ def __init__(
super().__init__(
create_env_fn=[create_env_fn],
policy=policy,
policy_factory=policy_factory,
total_frames=total_frames,
create_env_kwargs=[create_env_kwargs],
max_frames_per_traj=max_frames_per_traj,
13 changes: 11 additions & 2 deletions torchrl/collectors/distributed/generic.py
@@ -14,7 +14,7 @@
from typing import Callable, OrderedDict

import torch.cuda
from tensordict import TensorDict
from tensordict import TensorDict, TensorDictBase
from torch import nn

from torchrl._utils import _ProcessNoWarn, logger as torchrl_logger, VERBOSE
@@ -270,7 +270,15 @@ class DistributedDataCollector(DataCollectorBase):

- In all other cases an attempt to wrap it will be undergone as such: ``TensorDictModule(policy, in_keys=env_obs_key, out_keys=env.action_keys)``.

.. note:: If the policy needs to be passed as a policy factory (e.g., in case it mustn't be serialized /
pickled directly), the :arg:`policy_factory` should be used instead.

Keyword Args:
policy_factory (Callable[[], Callable], optional): a callable that returns
a policy instance. This is exclusive with the `policy` argument.

.. note:: `policy_factory` comes in handy whenever the policy cannot be serialized.

frames_per_batch (int): A keyword-only argument representing the total
number of elements in a batch.
total_frames (int): A keyword-only argument representing the total
@@ -406,8 +414,9 @@ def __init__(
def __init__(
self,
create_env_fn,
policy,
policy: Callable[[TensorDictBase], TensorDictBase] | None = None,
*,
policy_factory: Callable[[], Callable] | None = None,
frames_per_batch: int,
total_frames: int = -1,
device: torch.device | list[torch.device] = None,