Merged
Changes from 5 commits

55 changes: 42 additions & 13 deletions test/test_distributed.py
@@ -556,19 +556,21 @@ def test_distributed_collector_updatepolicy(self, collector_class, sync):
total = 0
first_batch = None
last_batch = None
for i, data in enumerate(collector):
total += data.numel()
assert data.numel() == frames_per_batch
if i == 0:
first_batch = data
policy.weight.data += 1
collector.update_policy_weights_()
elif total == total_frames - frames_per_batch:
last_batch = data
assert (first_batch["action"] == 1).all(), first_batch["action"]
assert (last_batch["action"] == 2).all(), last_batch["action"]
collector.shutdown()
assert total == total_frames
try:
for i, data in enumerate(collector):
total += data.numel()
assert data.numel() == frames_per_batch
if i == 0:
first_batch = data
policy.weight.data += 1
collector.update_policy_weights_()
elif total == total_frames - frames_per_batch:
last_batch = data
assert (first_batch["action"] == 1).all(), first_batch["action"]
assert (last_batch["action"] == 2).all(), last_batch["action"]
assert total == total_frames
finally:
collector.shutdown()

@pytest.mark.parametrize("storage", [None, partial(LazyTensorStorage, 1000)])
@pytest.mark.parametrize(
@@ -587,6 +589,33 @@ def test_ray_replaybuffer(self, storage, sampler, writer):
if sampler is SamplerWithoutReplacement:
assert sample["a"].unique().numel() == sample.numel()

# class CustomCollectorCls(SyncDataCollector):

Contributor: nit: remove comments

# def __init__(self, create_env_fn, **kwargs):
# policy = lambda td: td.set("action", torch.full(td.shape, 2))
# super().__init__(create_env_fn, policy, **kwargs)

def test_ray_collector_policy_constructor(self):
n_collectors = 2
frames_per_batch = 50
total_frames = 300
env = CountingEnv

def policy_constructor():
return lambda td: td.set("action", torch.full(td.shape, 2))

collector = self.distributed_class()(
[env] * n_collectors,
collector_class=SyncDataCollector.from_policy_factory(policy_constructor),
total_frames=total_frames,
frames_per_batch=frames_per_batch,
**self.distributed_kwargs(),
)
try:
for data in collector:
assert (data["action"] == 2).all()
finally:
collector.shutdown()


if __name__ == "__main__":
args, unknown = argparse.ArgumentParser().parse_known_args()

80 changes: 68 additions & 12 deletions torchrl/collectors/collectors.py
@@ -21,7 +21,7 @@
from multiprocessing.managers import SyncManager
from queue import Empty
from textwrap import indent
from typing import Any, Callable, Iterator, Sequence
from typing import Any, Callable, Iterator, Sequence, TypeVar

import numpy as np
import torch
@@ -86,6 +86,8 @@ def cudagraph_mark_step_begin():

_is_osx = sys.platform.startswith("darwin")

T = TypeVar("T")


class _Interruptor:
"""A class for managing the collection state of a process.
@@ -315,6 +317,43 @@ def __len__(self) -> int:
return -(self.total_frames // -self.requested_frames_per_batch)
raise RuntimeError("Non-terminating collectors do not have a length")

@classmethod
def from_policy_factory(
Collaborator (author):

@mikaylagawarecki This is a bit convoluted.

We could also have one more kwarg in the constructors, but since there are many kwargs there already, and many collector subclasses, IDK if that's the most economical solution.

Happy to make this a kwarg that is exclusive with the policy arg if you think it's more suited.
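
A minimal sketch of what that alternative could look like, assuming a hypothetical `policy_factory` keyword that is mutually exclusive with `policy` (not part of this PR):

```python
# Hypothetical alternative (not in this PR): accept a policy factory directly in
# the collector constructor, mutually exclusive with the existing `policy` arg.
class _CollectorWithFactoryKwarg:
    def __init__(self, create_env_fn, policy=None, *, policy_factory=None, **kwargs):
        if policy is not None and policy_factory is not None:
            raise TypeError("Only one of `policy` and `policy_factory` may be passed.")
        if policy_factory is not None:
            # The factory is called where the collector is actually built (e.g. on a
            # remote worker), so only the factory itself has to be serializable.
            policy = policy_factory()
        self.create_env_fn = create_env_fn
        self.policy = policy
```

The trade-off is the one noted above: every collector subclass would grow one more keyword argument, whereas the classmethod keeps the existing constructors untouched.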

Contributor (@mikaylagawarecki, Mar 11, 2025):

I think this is ok.

Just a question -- do you think that giving the CustomCollectorCls a __name__ that depends on its cls would be helpful for debugging, or would that make it harder to find where the custom wrapper class was defined? 😅
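
One possible sketch of that naming idea (hypothetical, not in this PR), placed just before `return CustomCollectorCls` in the factory below:

```python
# Hypothetical: derive a descriptive __name__/__qualname__ from the wrapped class
# so reprs and tracebacks show e.g. "SyncDataCollectorFromPolicyFactory".
CustomCollectorCls.__name__ = f"{cls.__name__}FromPolicyFactory"
CustomCollectorCls.__qualname__ = CustomCollectorCls.__name__
```

The generated class's `__module__` would still point at collectors.py rather than the call site, which is the "harder to find where it was defined" part of the question.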

cls: type[T],
policy_factory: Callable[[], Callable[[TensorDictBase], TensorDictBase]],
) -> T:
"""Creates a custom subclass of Collector that instantiates a policy from a factory.

Args:
policy_factory (Callable[[], Callable[[TensorDictBase], TensorDictBase]]): a factory function that returns
a valid policy.

Example:
>>> import torch
>>>
>>> from torchrl.collectors import SyncDataCollector
>>> from torchrl.envs import GymEnv
>>>
>>> def factory():
... return lambda td: td.set("action", torch.ones((1)))
>>> cls = SyncDataCollector.from_policy_factory(factory)
>>> collector = cls(GymEnv("Pendulum-v1"), total_frames=10, frames_per_batch=5)
>>> for d in collector:
... assert (d["action"] == 1).all()

"""

class CustomCollectorCls(cls):
def __init__(self, *args, **kwargs):
if len(args) > 1 or "policy" in kwargs:
raise TypeError(
"The policy cannot be passed to the constructor of a collector class "
"that instantiates the policy from a factory."
)
super().__init__(*args, policy_factory(), **kwargs)

return CustomCollectorCls


@accept_remote_rref_udf_invocation
class SyncDataCollector(DataCollectorBase):
@@ -343,6 +382,10 @@ class SyncDataCollector(DataCollectorBase):

- In all other cases an attempt to wrap it will be undergone as such: ``TensorDictModule(policy, in_keys=env_obs_key, out_keys=env.action_keys)``.

.. note:: If the policy needs to be passed as a policy factory (e.g., in case it mustn't be serialized /
pickled directly), the :meth:`~.from_policy_factory` method should be used to subclass the collector
and create a version that instantiates a specific version of the policy on demand.

Keyword Args:
frames_per_batch (int): A keyword-only argument representing the total
number of elements in a batch.
@@ -1429,17 +1472,22 @@ def load_state_dict(self, state_dict: OrderedDict, **kwargs) -> None:
self._iter = state_dict["iter"]

def __repr__(self) -> str:
env_str = indent(f"env={self.env}", 4 * " ")
policy_str = indent(f"policy={self.policy}", 4 * " ")
td_out_str = indent(f"td_out={getattr(self, '_final_rollout', None)}", 4 * " ")
string = (
f"{self.__class__.__name__}("
f"\n{env_str},"
f"\n{policy_str},"
f"\n{td_out_str},"
f"\nexploration={self.exploration_type})"
)
return string
try:
env_str = indent(f"env={self.env}", 4 * " ")
policy_str = indent(f"policy={self.policy}", 4 * " ")
td_out_str = indent(
f"td_out={getattr(self, '_final_rollout', None)}", 4 * " "
)
string = (
f"{self.__class__.__name__}("
f"\n{env_str},"
f"\n{policy_str},"
f"\n{td_out_str},"
f"\nexploration={self.exploration_type})"
)
return string
except AttributeError:
return f"{type(self).__name__}(not_init)"


class _MultiDataCollector(DataCollectorBase):
@@ -1469,6 +1517,10 @@ class _MultiDataCollector(DataCollectorBase):
- In all other cases an attempt to wrap it will be undergone as such:
``TensorDictModule(policy, in_keys=env_obs_key, out_keys=env.action_keys)``.

.. note:: If the policy needs to be passed as a policy factory (e.g., in case it mustn't be serialized /
pickled directly), the :meth:`~.from_policy_factory` method should be used to subclass the collector
and create a version that instantiates a specific version of the policy on demand.

Keyword Args:
frames_per_batch (int): A keyword-only argument representing the
total number of elements in a batch.
@@ -2782,6 +2834,10 @@ class aSyncDataCollector(MultiaSyncDataCollector):

- In all other cases an attempt to wrap it will be undergone as such: ``TensorDictModule(policy, in_keys=env_obs_key, out_keys=env.action_keys)``.

.. note:: If the policy needs to be passed as a policy factory (e.g., in case it mustn't be serialized /
pickled directly), the :meth:`~.from_policy_factory` method should be used to subclass the collector
and create a version that instantiates a specific version of the policy on demand.

Keyword Args:
frames_per_batch (int): A keyword-only argument representing the
total number of elements in a batch.

5 changes: 5 additions & 0 deletions torchrl/collectors/distributed/generic.py
@@ -270,6 +270,11 @@ class DistributedDataCollector(DataCollectorBase):

- In all other cases an attempt to wrap it will be undergone as such: ``TensorDictModule(policy, in_keys=env_obs_key, out_keys=env.action_keys)``.

.. note:: If the policy needs to be passed as a policy factory (e.g., in case it mustn't be serialized /
pickled directly), the :meth:`~.from_policy_factory` method should be used to subclass the collector
and create a version that instantiates a specific version of the policy on demand.
The new collector subclass should then be passed as :attr:`collector_class` keyword argument.

Keyword Args:
frames_per_batch (int): A keyword-only argument representing the total
number of elements in a batch.

38 changes: 28 additions & 10 deletions torchrl/collectors/distributed/ray.py
@@ -124,7 +124,7 @@ class RayCollector(DataCollectorBase):
Args:
create_env_fn (Callable or List[Callabled]): list of Callables, each returning an
instance of :class:`~torchrl.envs.EnvBase`.
policy (Callable): Policy to be executed in the environment.
policy (Callable, optional): Policy to be executed in the environment.
Must accept :class:`tensordict.tensordict.TensorDictBase` object as input.
If ``None`` is provided, the policy used will be a
:class:`~torchrl.collectors.RandomPolicy` instance with the environment
@@ -144,6 +144,11 @@

- In all other cases an attempt to wrap it will be undergone as such: ``TensorDictModule(policy, in_keys=env_obs_key, out_keys=env.action_keys)``.

.. note:: If the policy needs to be passed as a policy factory (e.g., in case it mustn't be serialized /
pickled directly), the :meth:`~.from_policy_factory` method should be used to subclass the collector
and create a version that instantiates a specific version of the policy on demand.
The new collector subclass should then be passed as :attr:`collector_class` keyword argument.
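
For reference, the new test in this PR exercises exactly this path; a standalone sketch of the same pattern (not part of the diff) might look as follows, assuming a local Gym install with Pendulum-v1 and illustrative frame counts. Only the factory is shipped to the worker; the policy is built there.

```python
# Usage sketch: pass a factory-generated collector class via `collector_class`.
import torch

from torchrl.collectors import SyncDataCollector
from torchrl.collectors.distributed import RayCollector
from torchrl.envs import GymEnv


def factory():
    # Built on the worker, so the policy itself never has to be pickled.
    return lambda td: td.set("action", torch.ones((1,)))


collector = RayCollector(
    [lambda: GymEnv("Pendulum-v1")],
    collector_class=SyncDataCollector.from_policy_factory(factory),
    frames_per_batch=50,
    total_frames=200,
)
try:
    for data in collector:
        assert (data["action"] == 1).all()
finally:
    collector.shutdown()
```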

Keyword Args:
frames_per_batch (int): A keyword-only argument representing the
total number of elements in a batch.
@@ -291,7 +296,7 @@ class RayCollector(DataCollectorBase):
def __init__(
self,
create_env_fn: Callable | EnvBase | list[Callable] | list[EnvBase],
policy: Callable[[TensorDict], TensorDict],
policy: Callable[[TensorDict], TensorDict] | None = None,
*,
frames_per_batch: int,
total_frames: int = -1,
@@ -410,8 +415,16 @@ def check_list_length_consistency(*lists):
collector_class = MultiSyncDataCollector
elif collector_class == "single":
collector_class = SyncDataCollector
collector_class.as_remote = as_remote
collector_class.print_remote_collector_info = print_remote_collector_info
elif not isinstance(collector_class, type) or not issubclass(
collector_class, DataCollectorBase
):
raise TypeError(
"The collector_class must be an instance of DataCollectorBase."
)
if not hasattr(collector_class, "as_remote"):
collector_class.as_remote = as_remote
if not hasattr(collector_class, "print_remote_collector_info"):
collector_class.print_remote_collector_info = print_remote_collector_info

self.replay_buffer = replay_buffer
self._local_policy = policy
@@ -540,11 +553,12 @@ def policy_device(self, value):
self._policy_device = [value] * self.num_collectors

@staticmethod
def _make_collector(cls, env_maker, policy, other_params):
def _make_collector(cls, *, env_maker, policy, other_params):
"""Create a single collector instance."""
if policy is not None:
other_params["policy"] = policy
collector = cls(
env_maker,
policy,
total_frames=-1,
**other_params,
)
@@ -565,11 +579,15 @@ def add_collectors(
cls = self.collector_class.as_remote(remote_config).remote
collector = self._make_collector(
cls,
[env_maker] * num_envs
if self.collector_class is not SyncDataCollector
env_maker=[env_maker] * num_envs
if num_envs > 1
or (
isinstance(self.collector_class, type)
and not issubclass(self.collector_class, SyncDataCollector)
)
else env_maker,
policy,
other_params,
policy=policy,
other_params=other_params,
)
self._remote_collectors.append(collector)


5 changes: 5 additions & 0 deletions torchrl/collectors/distributed/rpc.py
@@ -117,6 +117,11 @@ class RPCDataCollector(DataCollectorBase):

- In all other cases an attempt to wrap it will be undergone as such: ``TensorDictModule(policy, in_keys=env_obs_key, out_keys=env.action_keys)``.

.. note:: If the policy needs to be passed as a policy factory (e.g., in case it mustn't be serialized /
pickled directly), the :meth:`~.from_policy_factory` method should be used to subclass the collector
and create a version that instantiates a specific version of the policy on demand.
The new collector subclass should then be passed as :attr:`collector_class` keyword argument.

Keyword Args:
frames_per_batch (int): A keyword-only argument representing the total
number of elements in a batch.

5 changes: 5 additions & 0 deletions torchrl/collectors/distributed/sync.py
@@ -150,6 +150,11 @@ class DistributedSyncDataCollector(DataCollectorBase):

- In all other cases an attempt to wrap it will be undergone as such: ``TensorDictModule(policy, in_keys=env_obs_key, out_keys=env.action_keys)``.

.. note:: If the policy needs to be passed as a policy factory (e.g., in case it mustn't be serialized /
pickled directly), the :meth:`~.from_policy_factory` method should be used to subclass the collector
and create a version that instantiates a specific version of the policy on demand.
The new collector subclass should then be passed as :attr:`collector_class` keyword argument.

Keyword Args:
frames_per_batch (int): A keyword-only argument representing the total
number of elements in a batch.