diff --git a/docs/source/reference/collectors.rst b/docs/source/reference/collectors.rst
index 50d7e35ed52..61109bca4f8 100644
--- a/docs/source/reference/collectors.rst
+++ b/docs/source/reference/collectors.rst
@@ -117,6 +117,76 @@ try to limit the cases where a deepcopy will be executed. The following chart sh
Policy copy decision tree in Collectors.
+
+Weight Synchronization in Distributed Environments
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+In distributed and multiprocessed environments, ensuring that all instances of a policy are synchronized with the
+latest trained weights is crucial for consistent performance. The collector API provides a flexible and extensible
+mechanism for updating policy weights across different devices and processes, accommodating various deployment scenarios.
+
+Local and Remote Weight Updaters
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+The weight synchronization process is facilitated by two main components: :class:`~torchrl.collectors.LocalWeightUpdaterBase`
+and :class:`~torchrl.collectors.RemoteWeightUpdaterBase`. These base classes provide a structured interface for
+implementing custom weight update logic, allowing users to tailor the synchronization process to their specific needs.
+
+- :class:`~torchrl.collectors.LocalWeightUpdaterBase`: This component is responsible for updating the policy weights on
+ the local inference worker. It is particularly useful when the training and inference occur on the same machine but on
+ different devices. Users can extend this class to define how weights are fetched from a server and applied locally.
+ It is also the extension point for collectors where the workers need to ask for weight updates (in contrast with
+ situations where the server decides when to update the worker policies).
+- :class:`~torchrl.collectors.RemoteWeightUpdaterBase`: This component handles the distribution of policy weights to
+ remote inference workers. It is essential in distributed systems where multiple workers need to be kept in sync with
+ the central policy. Users can extend this class to implement custom logic for synchronizing weights across a network of
+ devices or processes.
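+
+The collector itself does not implement any synchronization logic: it delegates to these two hooks. A simplified
+sketch of what :meth:`~torchrl.collectors.DataCollectorBase.update_policy_weights_` does under the hood (the actual
+method also validates the ``worker_ids`` argument) is:
+
+.. code-block:: python
+
+    # Pull: the inference worker downloads the weights from the server
+    if collector.local_weights_updater is not None:
+        collector.local_weights_updater(policy_weights)
+    # Push: the server uploads the weights to the remote workers
+    if collector.remote_weights_updater is not None:
+        collector.remote_weights_updater(policy_weights, worker_ids=worker_ids)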
+
+Extending the Updater Classes
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+To accommodate diverse use cases, the API allows users to extend the updater classes with custom implementations.
+This flexibility is particularly beneficial in scenarios involving complex network architectures or specialized hardware
+setups. By implementing the abstract methods in these base classes, users can define how weights are retrieved,
+transformed, and applied, ensuring seamless integration with their existing infrastructure.
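+
+As an illustration, a bare-bones local updater might look like the sketch below. The ``fetch_from_server``
+callable is a placeholder for whatever transport is used to retrieve the weights (a file system, a parameter
+server, an RPC call, etc.):
+
+.. code-block:: python
+
+    from tensordict import TensorDictBase
+    from torchrl.collectors import LocalWeightUpdaterBase
+
+    class PollingLocalWeightUpdater(LocalWeightUpdaterBase):
+        def __init__(self, policy_weights: TensorDictBase, fetch_from_server):
+            self.policy_weights = policy_weights
+            self.fetch_from_server = fetch_from_server
+
+        def _get_server_weights(self) -> TensorDictBase:
+            # Download the latest weights from the server
+            return self.fetch_from_server()
+
+        def _get_local_weights(self) -> TensorDictBase:
+            return self.policy_weights
+
+        def _maybe_map_weights(self, server_weights, local_weights):
+            # No renaming or dtype/device mapping is needed in this sketch
+            return server_weights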
+
+Default Implementations
+~~~~~~~~~~~~~~~~~~~~~~~
+
+For common scenarios, the API provides default implementations of these updaters, such as
+:class:`~torchrl.collectors.VanillaLocalWeightUpdater`, :class:`~torchrl.collectors.MultiProcessedRemoteWeightUpdate`,
+:class:`~torchrl.collectors.RayRemoteWeightUpdater`, :class:`~torchrl.collectors.distributed.RPCRemoteWeightUpdater`, and
+:class:`~torchrl.collectors.distributed.DistributedRemoteWeightUpdater`.
+These implementations cover a range of typical deployment configurations, from single-device setups to large-scale
+distributed systems.
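+
+For instance, a :class:`~torchrl.collectors.SyncDataCollector` built without an explicit updater falls back on a
+:class:`~torchrl.collectors.VanillaLocalWeightUpdater`. The sketch below (the environment and policy are purely
+illustrative) shows the resulting workflow:
+
+.. code-block:: python
+
+    from tensordict.nn import TensorDictModule
+    from torch import nn
+    from torchrl.collectors import SyncDataCollector
+    from torchrl.envs.libs.gym import GymEnv
+
+    env_maker = lambda: GymEnv("Pendulum-v1")
+    policy = TensorDictModule(nn.Linear(3, 1), in_keys=["observation"], out_keys=["action"])
+    collector = SyncDataCollector(env_maker, policy, frames_per_batch=200, total_frames=2000)
+    for data in collector:
+        ...  # run a training step on the policy
+        # With the default VanillaLocalWeightUpdater, this refreshes the collector's
+        # copy of the policy weights from the trained parameters.
+        collector.update_policy_weights_()
+    collector.shutdown()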
+
+Practical Considerations
+~~~~~~~~~~~~~~~~~~~~~~~~
+
+When designing a system that leverages this API, consider the following:
+
+- Network Latency: In distributed environments, network latency can impact the speed of weight updates. Ensure that your
+ implementation accounts for potential delays and optimizes data transfer where possible.
+- Consistency: Ensure that all workers receive the updated weights in a timely manner to maintain consistency across
+ the system. This is particularly important in reinforcement learning scenarios where stale weights can lead to
+ suboptimal policy performance.
+- Scalability: As your system grows, the weight synchronization mechanism should scale efficiently. Consider the
+ overhead of broadcasting weights to a large number of workers and optimize the process to minimize bottlenecks.
+
+By leveraging the API, users can achieve robust and efficient weight synchronization across a variety of deployment
+scenarios, ensuring that their policies remain up-to-date and performant.
+
+.. currentmodule:: torchrl.collectors
+
+.. autosummary::
+ :toctree: generated/
+ :template: rl_template.rst
+
+    LocalWeightUpdaterBase
+    RemoteWeightUpdaterBase
+    VanillaLocalWeightUpdater
+    MultiProcessedRemoteWeightUpdate
+    RayRemoteWeightUpdater
+
+.. currentmodule:: torchrl.collectors.distributed
+
+.. autosummary::
+    :toctree: generated/
+    :template: rl_template.rst
+
+    DistributedRemoteWeightUpdater
+    RPCRemoteWeightUpdater
+
Collectors and replay buffers interoperability
----------------------------------------------
diff --git a/test/test_distributed.py b/test/test_distributed.py
index c49596bb26f..41797ae066f 100644
--- a/test/test_distributed.py
+++ b/test/test_distributed.py
@@ -390,25 +390,29 @@ def _test_distributed_collector_updatepolicy(
update_interval=update_interval,
**cls.distributed_kwargs(),
)
- total = 0
- first_batch = None
- last_batch = None
- for i, data in enumerate(collector):
- total += data.numel()
- assert data.numel() == frames_per_batch
- if i == 0:
- first_batch = data
- policy.weight.data += 1
- elif total == total_frames - frames_per_batch:
- last_batch = data
- assert (first_batch["action"] == 1).all(), first_batch["action"]
- if update_interval == 1:
- assert (last_batch["action"] == 2).all(), last_batch["action"]
- else:
- assert (last_batch["action"] == 1).all(), last_batch["action"]
- collector.shutdown()
- assert total == total_frames
- queue.put("passed")
+    try:
+        total = 0
+        first_batch = None
+        last_batch = None
+        for i, data in enumerate(collector):
+            total += data.numel()
+            assert data.numel() == frames_per_batch
+            if i == 0:
+                first_batch = data
+                policy.weight.data += 1
+            elif total == total_frames - frames_per_batch:
+                last_batch = data
+        assert (first_batch["action"] == 1).all(), first_batch["action"]
+        if update_interval == 1:
+            assert (last_batch["action"] == 2).all(), last_batch["action"]
+        else:
+            assert (last_batch["action"] == 1).all(), last_batch["action"]
+        assert total == total_frames
+        queue.put("passed")
+    except Exception:
+        queue.put("not passed")
+        raise
+    finally:
+        collector.shutdown()
@pytest.mark.parametrize(
"collector_class",
@@ -490,12 +494,14 @@ def test_distributed_collector_sync(self, sync, frames_per_batch=200):
sync=sync,
**self.distributed_kwargs(),
)
- total = 0
- for data in collector:
- total += data.numel()
- assert data.numel() == frames_per_batch
- collector.shutdown()
- assert total == 200
+ try:
+ total = 0
+ for data in collector:
+ total += data.numel()
+ assert data.numel() == frames_per_batch
+ assert total == 200
+ finally:
+ collector.shutdown()
@pytest.mark.parametrize(
"collector_class",
@@ -517,12 +523,14 @@ def test_distributed_collector_class(self, collector_class):
frames_per_batch=frames_per_batch,
**self.distributed_kwargs(),
)
- total = 0
- for data in collector:
- total += data.numel()
- assert data.numel() == frames_per_batch
- collector.shutdown()
- assert total == 200
+ try:
+ total = 0
+ for data in collector:
+ total += data.numel()
+ assert data.numel() == frames_per_batch
+ assert total == 200
+ finally:
+ collector.shutdown()
@pytest.mark.parametrize(
"collector_class",
diff --git a/torchrl/collectors/__init__.py b/torchrl/collectors/__init__.py
index 2d40522bb07..b3170e3631a 100644
--- a/torchrl/collectors/__init__.py
+++ b/torchrl/collectors/__init__.py
@@ -12,9 +12,21 @@
MultiSyncDataCollector,
SyncDataCollector,
)
+from .weight_update import (
+ LocalWeightUpdaterBase,
+ MultiProcessedRemoteWeightUpdate,
+ RayRemoteWeightUpdater,
+ RemoteWeightUpdaterBase,
+ VanillaLocalWeightUpdater,
+)
__all__ = [
"RandomPolicy",
+ "LocalWeightUpdaterBase",
+ "RemoteWeightUpdaterBase",
+ "VanillaLocalWeightUpdater",
+ "RayRemoteWeightUpdater",
+ "MultiProcessedRemoteWeightUpdate",
"aSyncDataCollector",
"DataCollectorBase",
"MultiaSyncDataCollector",
diff --git a/torchrl/collectors/collectors.py b/torchrl/collectors/collectors.py
index 58f67229413..3ff35dbb560 100644
--- a/torchrl/collectors/collectors.py
+++ b/torchrl/collectors/collectors.py
@@ -26,12 +26,7 @@
import numpy as np
import torch
import torch.nn as nn
-from tensordict import (
- LazyStackedTensorDict,
- TensorDict,
- TensorDictBase,
- TensorDictParams,
-)
+from tensordict import LazyStackedTensorDict, TensorDict, TensorDictBase
from tensordict.base import NO_DEFAULT
from tensordict.nn import CudaGraphModule, TensorDictModule
from tensordict.utils import Buffer
@@ -53,6 +48,12 @@
VERBOSE,
)
from torchrl.collectors.utils import split_trajectories
+from torchrl.collectors.weight_update import (
+ LocalWeightUpdaterBase,
+ MultiProcessedRemoteWeightUpdate,
+ RemoteWeightUpdaterBase,
+ VanillaLocalWeightUpdater,
+)
from torchrl.data import ReplayBuffer
from torchrl.data.tensor_specs import TensorSpec
from torchrl.data.utils import CloudpickleWrapper, DEVICE_TYPING
@@ -151,6 +152,8 @@ class DataCollectorBase(IterableDataset, metaclass=abc.ABCMeta):
trust_policy: bool
compiled_policy: bool
cudagraphed_policy: bool
+ local_weights_updater: LocalWeightUpdaterBase | None = None
+ remote_weights_updater: RemoteWeightUpdaterBase | None = None
def _get_policy_and_device(
self,
@@ -242,19 +245,50 @@ def map_weight(
return policy, get_original_weights
def update_policy_weights_(
- self, policy_weights: TensorDictBase | None = None
+ self,
+ policy_weights: TensorDictBase | None = None,
+ *,
+ worker_ids: int | list[int] | torch.device | list[torch.device] | None = None,
+ **kwargs,
) -> None:
- """Updates the policy weights if the policy of the data collector and the trained policy live on different devices.
+ """Updates the policy weights for the data collector, accommodating both local and remote execution contexts.
+
+ This method ensures that the policy weights used by the data collector are synchronized with the latest
+ trained weights. It supports both local and remote weight updates, depending on the configuration of the
+ data collector. The local (download) update is performed before the remote (upload) update, such that weights
+        can be transferred to the child workers from a server.
Args:
- policy_weights (TensorDictBase, optional): if provided, a TensorDict containing
- the weights of the policy to be used for the udpdate.
+ policy_weights (TensorDictBase, optional): A TensorDict containing the weights of the policy to be used
+ for the update. If not provided, the method will attempt to fetch the weights using the configured
+ weight updater.
+ worker_ids (int | List[int] | torch.device | List[torch.device] | None, optional): Identifiers for the
+ workers that need to be updated. This is relevant when using a remote weights updater, which must
+ be specified during the data collector's initialization. If `worker_ids` is provided without a
+ configured remote weights updater, a TypeError will be raised.
+
+ Raises:
+ TypeError: If `worker_ids` is provided but no `remote_weights_updater` is configured.
+
+ .. note::
+
+ - The method first attempts to update weights locally using `local_weights_updater`, if available.
+ - If a `remote_weights_updater` is configured, it will be used to update the specified remote workers.
+ - Users can extend the `LocalWeightUpdaterBase` and `RemoteWeightUpdaterBase` classes to customize
+ the weight update logic for specific use cases. This method should not be overwritten.
+
+        .. seealso:: :class:`~torchrl.collectors.LocalWeightUpdaterBase` and
+            :class:`~torchrl.collectors.RemoteWeightUpdaterBase`.
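+
+        Examples:
+            >>> # Typical usage (sketch): the collector must have been constructed with the
+            >>> # relevant updaters, and valid worker identifiers depend on the remote
+            >>> # updater in use (ranks, indices or devices).
+            >>> collector.update_policy_weights_()
+            >>> collector.update_policy_weights_(worker_ids=[1, 2])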
"""
- if policy_weights is not None:
- self.policy_weights.data.update_(policy_weights)
- elif self.get_weights_fn is not None:
- self.policy_weights.data.update_(self.get_weights_fn())
+ if self.local_weights_updater is not None:
+ self.local_weights_updater(policy_weights, **kwargs)
+ if self.remote_weights_updater is not None:
+ self.remote_weights_updater(policy_weights, worker_ids=worker_ids, **kwargs)
+ elif worker_ids is not None:
+ raise TypeError(
+ "worker_ids was passed but remote_weights_updater was None."
+ )
def __iter__(self) -> Iterator[TensorDictBase]:
try:
@@ -459,6 +493,13 @@ class SyncDataCollector(DataCollectorBase):
or `ManiSkills `_) cuda synchronization may cause unexpected
crashes.
Defaults to ``False``.
+ local_weights_updater (LocalWeightUpdaterBase, optional): An instance of :class:`~torchrl.collectors.LocalWeightUpdaterBase`
+ or its subclass, responsible for updating the policy weights on the local inference worker.
+ If not provided, a :class:`~torchrl.collectors.VanillaLocalWeightUpdater` will be used by default,
+ which directly fetches and applies the weights from the server.
+ remote_weights_updater (RemoteWeightUpdaterBase, optional): An instance of :class:`~torchrl.collectors.RemoteWeightUpdaterBase`
+ or its subclass, responsible for updating the policy weights on remote inference workers.
+ This is typically not used in :class:`~torchrl.collectors.SyncDataCollector` as it operates in a single-process environment.
Examples:
>>> from torchrl.envs.libs.gym import GymEnv
@@ -549,6 +590,8 @@ def __init__(
compile_policy: bool | dict[str, Any] | None = None,
cudagraph_policy: bool | dict[str, Any] | None = None,
no_cuda_sync: bool = False,
+ local_weights_updater: LocalWeightUpdaterBase | None = None,
+ remote_weights_updater: RemoteWeightUpdaterBase | None = None,
**kwargs,
):
from torchrl.envs.batched_envs import BatchedEnvBase
@@ -788,6 +831,14 @@ def __init__(
self._frames = 0
self._iter = -1
+ if local_weights_updater is None:
+ local_weights_updater = VanillaLocalWeightUpdater(
+ weight_getter=self.get_weights_fn, policy_weights=self.policy_weights
+ )
+
+ self.local_weights_updater = local_weights_updater
+ self.remote_weights_updater = remote_weights_updater
+
@property
def _traj_pool(self):
pool = getattr(self, "_traj_pool_val", None)
@@ -998,9 +1049,14 @@ def next(self):
# for RPC
def update_policy_weights_(
- self, policy_weights: TensorDictBase | None = None
+ self,
+ policy_weights: TensorDictBase | None = None,
+ *,
+ worker_ids: int | list[int] | torch.device | list[torch.device] | None = None,
) -> None:
- super().update_policy_weights_(policy_weights)
+ super().update_policy_weights_(
+ policy_weights=policy_weights, worker_ids=worker_ids
+ )
def set_seed(self, seed: int, static_seed: bool = False) -> int:
"""Sets the seeds of the environments stored in the DataCollector.
@@ -1635,6 +1691,13 @@ class _MultiDataCollector(DataCollectorBase):
or `ManiSkills `_) cuda synchronization may cause unexpected
crashes.
Defaults to ``False``.
+ local_weights_updater (LocalWeightUpdaterBase, optional): An instance of :class:`~torchrl.collectors.LocalWeightUpdaterBase`
+ or its subclass, responsible for updating the policy weights on each local inference worker.
+        If not provided, no local weight updater is used.
+ remote_weights_updater (RemoteWeightUpdaterBase, optional): An instance of :class:`~torchrl.collectors.RemoteWeightUpdaterBase`
+ or its subclass, responsible for updating the policy weights on remote inference workers.
+ If not provided, a :class:`~torchrl.collectors.MultiProcessedRemoteWeightUpdate` will be used by default,
+ which handles weight synchronization across multiple processes.
"""
@@ -1672,6 +1735,8 @@ def __init__(
compile_policy: bool | dict[str, Any] | None = None,
cudagraph_policy: bool | dict[str, Any] | None = None,
no_cuda_sync: bool = False,
+ remote_weights_updater: RemoteWeightUpdaterBase | None = None,
+ local_weights_updater: LocalWeightUpdaterBase | None = None,
):
self.closed = True
self.num_workers = len(create_env_fn)
@@ -1725,7 +1790,6 @@ def __init__(
replay_buffer.share()
self._policy_weights_dict = {}
- self._get_weights_fn_dict = {}
if trust_policy is None:
trust_policy = policy is not None and isinstance(policy, CudaGraphModule)
@@ -1751,13 +1815,21 @@ def __init__(
else TensorDict()
)
self._policy_weights_dict[policy_device] = weights
- self._get_weights_fn_dict[policy_device] = get_weights_fn
- else:
+ self._get_weights_fn = get_weights_fn
+ if remote_weights_updater is None:
+ remote_weights_updater = MultiProcessedRemoteWeightUpdate(
+ get_server_weights=self._get_weights_fn,
+ policy_weights=self._policy_weights_dict,
+ )
+ elif remote_weights_updater is None:
# TODO
raise NotImplementedError(
- "weight syncing is not supported for multiprocessed data collectors at the "
- "moment."
+ "remote_weights_updater cannot be None when policy_factory is provided."
)
+
+ self.remote_weights_updater = remote_weights_updater
+ self.local_weights_updater = local_weights_updater
+
self.policy = policy
remainder = 0
@@ -1913,21 +1985,6 @@ def _get_devices(
def frames_per_batch_worker(self):
raise NotImplementedError
- def update_policy_weights_(self, policy_weights=None) -> None:
- if isinstance(policy_weights, TensorDictParams):
- policy_weights = policy_weights.data
- for _device in self._policy_weights_dict:
- if policy_weights is not None:
- self._policy_weights_dict[_device].data.update_(policy_weights)
- elif self._get_weights_fn_dict[_device] is not None:
- original_weights = self._get_weights_fn_dict[_device]()
- if original_weights is None:
- # if the weights match in identity, we can spare a call to update_
- continue
- if isinstance(original_weights, TensorDictParams):
- original_weights = original_weights.data
- self._policy_weights_dict[_device].data.update_(original_weights)
-
@property
def _queue_len(self) -> int:
raise NotImplementedError
@@ -1958,6 +2015,10 @@ def _run_processes(self) -> None:
policy_device = self.policy_device[i]
storing_device = self.storing_device[i]
env_device = self.env_device[i]
+            # We take the weights and the policy, dispatch the weights to the policy
+            # locally, and send the policy to the remote process.
+            # This ensures that the shared weights for a given device are shared across
+            # all policies that rely on that device.
policy = self.policy
policy_weights = self._policy_weights_dict[policy_device]
if policy is not None and policy_weights is not None:
@@ -2312,9 +2373,14 @@ def load_state_dict(self, state_dict: OrderedDict) -> None:
# for RPC
def update_policy_weights_(
- self, policy_weights: TensorDictBase | None = None
+ self,
+ policy_weights: TensorDictBase | None = None,
+ *,
+ worker_ids: int | list[int] | torch.device | list[torch.device] | None = None,
) -> None:
- super().update_policy_weights_(policy_weights)
+ super().update_policy_weights_(
+ policy_weights=policy_weights, worker_ids=worker_ids
+ )
@property
def frames_per_batch_worker(self):
@@ -2676,9 +2742,14 @@ def load_state_dict(self, state_dict: OrderedDict) -> None:
# for RPC
def update_policy_weights_(
- self, policy_weights: TensorDictBase | None = None
+ self,
+ policy_weights: TensorDictBase | None = None,
+ *,
+ worker_ids: int | list[int] | torch.device | list[torch.device] | None = None,
) -> None:
- super().update_policy_weights_(policy_weights)
+ super().update_policy_weights_(
+ policy_weights=policy_weights, worker_ids=worker_ids
+ )
@property
def frames_per_batch_worker(self):
diff --git a/torchrl/collectors/distributed/__init__.py b/torchrl/collectors/distributed/__init__.py
index c28122c6c6a..79a52b1698a 100644
--- a/torchrl/collectors/distributed/__init__.py
+++ b/torchrl/collectors/distributed/__init__.py
@@ -3,7 +3,11 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
-from .generic import DEFAULT_SLURM_CONF, DistributedDataCollector
+from .generic import (
+ DEFAULT_SLURM_CONF,
+ DistributedDataCollector,
+ DistributedRemoteWeightUpdater,
+)
from .ray import RayCollector
-from .rpc import RPCDataCollector
+from .rpc import RPCDataCollector, RPCRemoteWeightUpdater
from .sync import DistributedSyncDataCollector
@@ -12,8 +16,10 @@
__all__ = [
"DEFAULT_SLURM_CONF",
"DistributedDataCollector",
- "RayCollector",
- "RPCDataCollector",
+ "DistributedRemoteWeightUpdater",
"DistributedSyncDataCollector",
+ "RPCDataCollector",
+ "RPCDataCollector",
+ "RayCollector",
"submitit_delayed_launcher",
]
diff --git a/torchrl/collectors/distributed/generic.py b/torchrl/collectors/distributed/generic.py
index 00ea799d55f..79fdee553c3 100644
--- a/torchrl/collectors/distributed/generic.py
+++ b/torchrl/collectors/distributed/generic.py
@@ -31,6 +31,10 @@
TCP_PORT,
)
from torchrl.collectors.utils import _NON_NN_POLICY_WEIGHTS, split_trajectories
+from torchrl.collectors.weight_update import (
+ LocalWeightUpdaterBase,
+ RemoteWeightUpdaterBase,
+)
from torchrl.data.utils import CloudpickleWrapper
from torchrl.envs.common import EnvBase
from torchrl.envs.env_creator import EnvCreator
@@ -227,7 +231,7 @@ def _run_collector(
# been updated
policy_weights.irecv(0)
# the policy has been updated: we can simply update the weights
- collector.update_policy_weights_(policy_weights)
+ collector.update_policy_weights_(policy_weights=policy_weights)
_store.set(f"NODE_{rank}_out", b"updated")
elif instruction.startswith(b"seeding"):
seed = int(instruction.split(b"seeding_"))
@@ -407,6 +411,15 @@ class DistributedDataCollector(DataCollectorBase):
to learn more.
Defaults to ``"submitit"``.
tcp_port (int, optional): the TCP port to be used. Defaults to 10003.
+ local_weights_updater (LocalWeightUpdaterBase, optional): An instance of :class:`~torchrl.collectors.LocalWeightUpdaterBase`
+ or its subclass, responsible for updating the policy weights on the local inference worker.
+ This is typically not used in :class:`~torchrl.collectors.distributed.DistributedDataCollector` as it
+ focuses on distributed environments.
+ remote_weights_updater (RemoteWeightUpdaterBase, optional): An instance of :class:`~torchrl.collectors.RemoteWeightUpdaterBase`
+ or its subclass, responsible for updating the policy weights on distributed inference workers.
+ If not provided, a :class:`~torchrl.collectors.distributed.DistributedRemoteWeightUpdater` will be used by
+ default, which handles weight synchronization across distributed workers.
+
"""
_VERBOSE = VERBOSE # for debugging
@@ -439,6 +452,8 @@ def __init__(
max_weight_update_interval: int = -1,
launcher: str = "submitit",
tcp_port: int = None,
+ remote_weights_updater: RemoteWeightUpdaterBase | None = None,
+ local_weights_updater: LocalWeightUpdaterBase | None = None,
):
if collector_class == "async":
@@ -453,6 +468,13 @@ def __init__(
if isinstance(policy, nn.Module):
policy_weights = TensorDict.from_module(policy)
policy_weights = policy_weights.data.lock_()
+ elif policy_factory is not None:
+ policy_weights = None
+ if remote_weights_updater is None:
+ raise RuntimeError(
+ "remote_weights_updater must be passed along with "
+ "a policy_factory."
+ )
else:
warnings.warn(_NON_NN_POLICY_WEIGHTS)
policy_weights = TensorDict(lock=True)
@@ -534,6 +556,15 @@ def __init__(
self._init_workers()
self._make_container()
+ if remote_weights_updater is None:
+ remote_weights_updater = DistributedRemoteWeightUpdater(
+ store=self._store,
+ policy_weights=self.policy_weights,
+ num_workers=self.num_workers,
+ sync=self._sync,
+ )
+ self.remote_weights_updater = remote_weights_updater
+ self.local_weights_updater = local_weights_updater
@property
def device(self) -> list[torch.device]:
@@ -817,7 +848,9 @@ def _iterator_dist(self):
self._batches_since_weight_update[j]
> self.max_weight_update_interval
):
- self.update_policy_weights_(rank)
+ self.update_policy_weights_(
+ policy_weights=None, worker_ids=rank
+ )
for i in range(self.num_workers):
rank = i + 1
@@ -866,7 +899,7 @@ def _next_async(self, total_frames, trackers):
_tracker.wait()
data = self._tensordict_out[i].clone()
if self.update_after_each_batch:
- self.update_policy_weights_(rank)
+ self.update_policy_weights_(worker_ids=rank)
total_frames += data.numel()
if total_frames < self.total_frames:
if self._VERBOSE:
@@ -880,31 +913,6 @@ def _next_async(self, total_frames, trackers):
break
return data, total_frames
- def update_policy_weights_(self, worker_rank=None) -> None:
- """Updates the weights of the worker nodes.
-
- Args:
- worker_rank (int, optional): if provided, only this worker weights
- will be updated.
- """
- if worker_rank is not None and worker_rank < 1:
- raise RuntimeError("worker_rank must be greater than 1")
- workers = range(self.num_workers) if worker_rank is None else [worker_rank - 1]
- for i in workers:
- rank = i + 1
- if self._VERBOSE:
- torchrl_logger.info(f"updating weights of {rank}")
- self._store.set(f"NODE_{rank}_in", b"update_weights")
- if self._sync:
- self.policy_weights.send(rank)
- else:
- self.policy_weights.isend(rank)
- self._batches_since_weight_update[i] = 0
- status = self._store.get(f"NODE_{rank}_out")
- if status != b"updated":
- raise RuntimeError(f"Expected 'updated' but got status {status}.")
- self._store.delete_key(f"NODE_{rank}_out")
-
def set_seed(self, seed: int, static_seed: bool = False) -> int:
for i in range(self.num_workers):
rank = i + 1
@@ -950,3 +958,98 @@ def shutdown(self):
pass
if self._VERBOSE:
torchrl_logger.info("collector shut down")
+
+
+class DistributedRemoteWeightUpdater(RemoteWeightUpdaterBase):
+ """A remote weight updater for synchronizing policy weights across distributed workers.
+
+ The `DistributedRemoteWeightUpdater` class provides a mechanism for updating the weights
+ of a policy across distributed inference workers. It is designed to work with the
+ :class:`~torchrl.collectors.distributed.DistributedDataCollector` to ensure that each worker receives the latest policy weights.
+ This class is typically used in distributed data collection scenarios where multiple workers
+ need to be kept in sync with the central policy weights.
+
+ Args:
+ store (dict[str, str]): A dictionary-like store used for communication between the server
+ and the distributed workers.
+ policy_weights (TensorDictBase): The current weights of the policy that need to be distributed
+ to the workers.
+ num_workers (int): The number of distributed workers that will receive the updated policy weights.
+        sync (bool): if ``True``, the update is synchronous (the server waits for the worker to acknowledge
+            the update before resuming collection).
+
+ Methods:
+ update_weights: Updates the weights on specified or all distributed workers.
+ all_worker_ids: Returns a list of all worker identifiers (not implemented in this class).
+ _sync_weights_with_worker: Synchronizes the server weights with a specific worker (not implemented).
+ _get_server_weights: Retrieves the latest weights from the server (not implemented).
+ _maybe_map_weights: Optionally maps server weights before distribution (not implemented).
+
+ .. note::
+ This class assumes that the server weights can be directly applied to the distributed workers
+ without any additional processing. If your use case requires more complex weight mapping or
+ synchronization logic, consider extending `RemoteWeightUpdaterBase` with a custom implementation.
+
+ Raises:
+ RuntimeError: If the worker rank is less than 1 or if the status returned by the store is not "updated".
+
+ .. seealso:: :class:`~torchrl.collectors.RemoteWeightUpdaterBase` and
+ :class:`~torchrl.collectors.distributed.DistributedDataCollector`.
+
+ """
+
+    _VERBOSE = VERBOSE  # for debugging
+
+ def __init__(
+ self,
+ store: dict[str, str],
+ policy_weights: TensorDictBase,
+ num_workers: int,
+ sync: bool,
+ ):
+ self._store = store
+ self.policy_weights = policy_weights
+ self.num_workers = num_workers
+ self._sync = sync
+ self._batches_since_weight_update = [0 for _ in range(self.num_workers)]
+
+ def _sync_weights_with_worker(
+ self, worker_id: int | torch.device, server_weights: TensorDictBase
+ ) -> TensorDictBase:
+ raise NotImplementedError
+
+ def _get_server_weights(self) -> TensorDictBase:
+ raise NotImplementedError
+
+ def _maybe_map_weights(self, server_weights: TensorDictBase) -> TensorDictBase:
+ raise NotImplementedError
+
+ def all_worker_ids(self) -> list[int] | list[torch.device]:
+ raise NotImplementedError
+
+ def update_weights(
+ self,
+ weights: TensorDictBase | None = None,
+ worker_ids: torch.device | int | list[int] | list[torch.device] | None = None,
+ ):
+        # worker_ids are 1-based worker ranks; convert them to 0-based indices
+        if isinstance(worker_ids, int):
+            worker_ids = [worker_ids]
+        if worker_ids is None:
+            workers = range(self.num_workers)
+        else:
+            if any(rank < 1 for rank in worker_ids):
+                raise RuntimeError("worker_rank must be at least 1")
+            workers = [rank - 1 for rank in worker_ids]
+ weights = self.policy_weights if weights is None else weights
+ for i in workers:
+ rank = i + 1
+ if self._VERBOSE:
+ torchrl_logger.info(f"updating weights of {rank}")
+ self._store.set(f"NODE_{rank}_in", b"update_weights")
+ if self._sync:
+ weights.send(rank)
+ else:
+ weights.isend(rank)
+ self._batches_since_weight_update[i] = 0
+ status = self._store.get(f"NODE_{rank}_out")
+ if status != b"updated":
+ raise RuntimeError(f"Expected 'updated' but got status {status}.")
+ self._store.delete_key(f"NODE_{rank}_out")
diff --git a/torchrl/collectors/distributed/ray.py b/torchrl/collectors/distributed/ray.py
index 279ef9e4aae..eb43abf7428 100644
--- a/torchrl/collectors/distributed/ray.py
+++ b/torchrl/collectors/distributed/ray.py
@@ -22,6 +22,11 @@
SyncDataCollector,
)
from torchrl.collectors.utils import _NON_NN_POLICY_WEIGHTS, split_trajectories
+from torchrl.collectors.weight_update import (
+ LocalWeightUpdaterBase,
+ RayRemoteWeightUpdater,
+ RemoteWeightUpdaterBase,
+)
from torchrl.data import ReplayBuffer
from torchrl.envs.common import EnvBase
from torchrl.envs.env_creator import EnvCreator
@@ -255,9 +260,10 @@ class RayCollector(DataCollectorBase):
all workers will see their weights updated. For ``sync=False``,
only the worker from which the data has been gathered will be
updated.
+ This is equivalent to `max_weight_update_interval=0`.
Defaults to ``False``, i.e. updates have to be executed manually
through
- ``torchrl.collectors.distributed.RayDistributedCollector.update_policy_weights_()``
+            :meth:`~torchrl.collectors.DataCollectorBase.update_policy_weights_`.
max_weight_update_interval (int, optional): the maximum number of
batches that can be collected before the policy weights of a worker
is updated.
@@ -271,6 +277,13 @@ class RayCollector(DataCollectorBase):
.. note:: although it is not enfoced (to allow users to implement their own replay buffer class), a
:class:`~torchrl.data.RayReplayBuffer` instance should be used here.
+ local_weights_updater (LocalWeightUpdaterBase, optional): An instance of :class:`~torchrl.collectors.LocalWeightUpdaterBase`
+ or its subclass, responsible for updating the policy weights on the local inference worker.
+ This is typically not used in :class:`~torchrl.collectors.RayCollector` as it focuses on distributed environments.
+ remote_weights_updater (RemoteWeightUpdaterBase, optional): An instance of :class:`~torchrl.collectors.RemoteWeightUpdaterBase`
+ or its subclass, responsible for updating the policy weights on remote inference workers managed by Ray.
+ If not provided, a :class:`~torchrl.collectors.RayRemoteWeightUpdater` will be used by default, leveraging
+ Ray's distributed capabilities.
Examples:
>>> from torch import nn
@@ -329,6 +342,8 @@ def __init__(
update_after_each_batch=False,
max_weight_update_interval=-1,
replay_buffer: ReplayBuffer = None,
+ remote_weights_updater: RemoteWeightUpdaterBase | None = None,
+ local_weights_updater: LocalWeightUpdaterBase | None = None,
):
self.frames_per_batch = frames_per_batch
if remote_configs is None:
@@ -441,16 +456,19 @@ def check_list_length_consistency(*lists):
policy_weights = TensorDict.from_module(self._local_policy)
policy_weights = policy_weights.data.lock_()
else:
- warnings.warn(_NON_NN_POLICY_WEIGHTS)
policy_weights = TensorDict(lock=True)
+ if remote_weights_updater is None:
+ warnings.warn(_NON_NN_POLICY_WEIGHTS)
self.policy_weights = policy_weights
self.collector_class = collector_class
self.collected_frames = 0
self.split_trajs = split_trajs
self.total_frames = total_frames
self.num_collectors = num_collectors
+
self.update_after_each_batch = update_after_each_batch
self.max_weight_update_interval = max_weight_update_interval
+
self.collector_kwargs = (
collector_kwargs if collector_kwargs is not None else [{}]
)
@@ -507,10 +525,18 @@ def check_list_length_consistency(*lists):
collector_kwargs,
remote_configs,
)
+ if remote_weights_updater is None:
+ remote_weights_updater = RayRemoteWeightUpdater(
+ policy_weights=policy_weights,
+ remote_collectors=self.remote_collectors,
+ max_interval=self.max_weight_update_interval,
+ )
+ self.remote_weights_updater = remote_weights_updater
+ self.local_weights_updater = local_weights_updater
# Print info of all remote workers
pending_samples = [
- e.print_remote_collector_info.remote() for e in self.remote_collectors()
+ e.print_remote_collector_info.remote() for e in self.remote_collectors
]
ray.wait(pending_samples)
@@ -605,6 +631,7 @@ def local_policy(self):
"""Returns local collector."""
return self._local_policy
+ @property
def remote_collectors(self):
"""Returns list of remote collectors."""
return self._remote_collectors
@@ -612,7 +639,7 @@ def remote_collectors(self):
def stop_remote_collectors(self):
"""Stops all remote collectors."""
for _ in range(len(self._remote_collectors)):
- collector = self.remote_collectors().pop()
+ collector = self.remote_collectors.pop()
# collector.__ray_terminate__.remote() # This will kill the actor but let pending tasks finish
ray.kill(
collector
@@ -650,14 +677,11 @@ def proc(data):
def _sync_iterator(self) -> Iterator[TensorDictBase]:
"""Collects one data batch per remote collector in each iteration."""
while self.collected_frames < self.total_frames:
- if self.update_after_each_batch:
+ if self.update_after_each_batch or self.max_weight_update_interval > -1:
self.update_policy_weights_()
- else:
- for j in range(self.num_collectors):
- self._batches_since_weight_update[j] += 1
# Ask for batches to all remote workers.
- pending_tasks = [e.next.remote() for e in self.remote_collectors()]
+ pending_tasks = [e.next.remote() for e in self.remote_collectors]
# Wait for all rollouts
samples_ready = []
@@ -683,15 +707,6 @@ def _sync_iterator(self) -> Iterator[TensorDictBase]:
yield out_td
- if self.max_weight_update_interval > -1:
- for j in range(self.num_collectors):
- rank = j + 1
- if (
- self._batches_since_weight_update[j]
- > self.max_weight_update_interval
- ):
- self.update_policy_weights_(rank)
-
if self._task is None:
self.shutdown()
@@ -720,19 +735,19 @@ async def async_shutdown(self):
def _async_iterator(self) -> Iterator[TensorDictBase]:
"""Collects a data batch from a single remote collector in each iteration."""
pending_tasks = {}
- for index, collector in enumerate(self.remote_collectors()):
+ for index, collector in enumerate(self.remote_collectors):
future = collector.next.remote()
pending_tasks[future] = index
while self.collected_frames < self.total_frames:
- if not len(list(pending_tasks.keys())) == len(self.remote_collectors()):
+ if not len(list(pending_tasks.keys())) == len(self.remote_collectors):
raise RuntimeError("Missing pending tasks, something went wrong")
# Wait for first worker to finish
wait_results = ray.wait(list(pending_tasks.keys()))
future = wait_results[0][0]
collector_index = pending_tasks.pop(future)
- collector = self.remote_collectors()[collector_index]
+ collector = self.remote_collectors[collector_index]
# Retrieve single rollouts
out_td = ray.get(future)
@@ -743,18 +758,8 @@ def _async_iterator(self) -> Iterator[TensorDictBase]:
yield out_td
- for j in range(self.num_collectors):
- self._batches_since_weight_update[j] += 1
- if self.update_after_each_batch:
- self.update_policy_weights_(worker_rank=collector_index + 1)
- elif self.max_weight_update_interval > -1:
- for j in range(self.num_collectors):
- rank = j + 1
- if (
- self._batches_since_weight_update[j]
- > self.max_weight_update_interval
- ):
- self.update_policy_weights_(rank)
+ if self.update_after_each_batch or self.max_weight_update_interval > -1:
+                self.update_policy_weights_(worker_ids=collector_index)
# Schedule a new collection task
future = collector.next.remote()
@@ -773,36 +778,16 @@ def _async_iterator(self) -> Iterator[TensorDictBase]:
if self._task is None:
self.shutdown()
- def update_policy_weights_(self, worker_rank=None) -> None:
- """Updates the weights of the worker nodes.
-
- Args:
- worker_rank (int, optional): if provided, only this worker weights
- will be updated.
- """
- # Update agent weights
- policy_weights_local_collector_ref = ray.put(self.policy_weights.detach())
-
- if worker_rank is None:
- for index, e in enumerate(self.remote_collectors()):
- e.update_policy_weights_.remote(policy_weights_local_collector_ref)
- self._batches_since_weight_update[index] = 0
- else:
- self.remote_collectors()[worker_rank - 1].update_policy_weights_.remote(
- policy_weights_local_collector_ref
- )
- self._batches_since_weight_update[worker_rank - 1] = 0
-
def set_seed(self, seed: int, static_seed: bool = False) -> list[int]:
"""Calls parent method for each remote collector iteratively and returns final seed."""
- for collector in self.remote_collectors():
+ for collector in self.remote_collectors:
seed = ray.get(object_refs=collector.set_seed.remote(seed, static_seed))
return seed
def state_dict(self) -> list[OrderedDict]:
"""Calls parent method for each remote collector and returns a list of results."""
futures = [
- collector.state_dict.remote() for collector in self.remote_collectors()
+ collector.state_dict.remote() for collector in self.remote_collectors
]
results = ray.get(object_refs=futures)
return results
@@ -812,8 +797,8 @@ def load_state_dict(self, state_dict: OrderedDict | list[OrderedDict]) -> None:
if isinstance(state_dict, OrderedDict):
state_dicts = [state_dict]
if len(state_dict) == 1:
- state_dicts = state_dict * len(self.remote_collectors())
- for collector, state_dict in zip(self.remote_collectors(), state_dicts):
+ state_dicts = state_dict * len(self.remote_collectors)
+ for collector, state_dict in zip(self.remote_collectors, state_dicts):
collector.load_state_dict.remote(state_dict)
def shutdown(self):
diff --git a/torchrl/collectors/distributed/rpc.py b/torchrl/collectors/distributed/rpc.py
index f205325ed30..7d198375251 100644
--- a/torchrl/collectors/distributed/rpc.py
+++ b/torchrl/collectors/distributed/rpc.py
@@ -47,6 +47,10 @@
MultiSyncDataCollector,
SyncDataCollector,
)
+from torchrl.collectors.weight_update import (
+ LocalWeightUpdaterBase,
+ RemoteWeightUpdaterBase,
+)
from torchrl.envs.common import EnvBase
from torchrl.envs.env_creator import EnvCreator
@@ -262,6 +266,14 @@ class RPCDataCollector(DataCollectorBase):
device used to pass data to main.
tensorpipe_options (dict, optional): a dictionary of keyword argument
to pass to :class:`torch.distributed.rpc.TensorPipeRpcBackendOption`.
+ local_weights_updater (LocalWeightUpdaterBase, optional): An instance of :class:`~torchrl.collectors.LocalWeightUpdaterBase`
+ or its subclass, responsible for updating the policy weights on the local inference worker. This is
+        typically not used in :class:`~torchrl.collectors.distributed.RPCDataCollector` as it focuses on
+ distributed environments.
+ remote_weights_updater (RemoteWeightUpdaterBase, optional): An instance of :class:`~torchrl.collectors.RemoteWeightUpdaterBase`
+ or its subclass, responsible for updating the policy weights on remote inference workers using RPC.
+ If not provided, an :class:`~torchrl.collectors.distributed.RPCRemoteWeightUpdater` will be used by default, which
+ handles weight synchronization via RPC.
"""
@@ -296,6 +308,8 @@ def __init__(
tcp_port=None,
visible_devices=None,
tensorpipe_options=None,
+ remote_weights_updater: RemoteWeightUpdaterBase | None = None,
+ local_weights_updater: LocalWeightUpdaterBase | None = None,
):
if collector_class == "async":
collector_class = MultiaSyncDataCollector
@@ -393,6 +407,16 @@ def __init__(
tensorpipe_options
)
self._init()
+ if remote_weights_updater is None:
+ remote_weights_updater = RPCRemoteWeightUpdater(
+ collector_infos=self.collector_infos,
+ collector_class=self.collector_class,
+ collector_rrefs=self.collector_rrefs,
+ policy_weights=self.policy_weights,
+ num_workers=self.num_workers,
+ )
+ self.local_weights_updater = local_weights_updater
+ self.remote_weights_updater = remote_weights_updater
@property
def device(self) -> list[torch.device]:
@@ -655,7 +679,11 @@ def iterator(self):
self._batches_since_weight_update[j]
> self.max_weight_update_interval
):
- self.update_policy_weights_([j], wait=False)
+ if self._VERBOSE:
+ torchrl_logger.info(
+ f"Updating policy of worker {j} with wait=False"
+ )
+ self.update_policy_weights_(worker_ids=[j], wait=False)
elif self.max_weight_update_interval > -1:
ranks = [
1
@@ -663,29 +691,11 @@ def iterator(self):
if self._batches_since_weight_update[j]
> self.max_weight_update_interval
]
- self.update_policy_weights_(ranks, wait=True)
-
- def update_policy_weights_(self, workers=None, wait=True) -> None:
- if workers is None:
- workers = list(range(self.num_workers))
- futures = []
- for i in workers:
- if self._VERBOSE:
- torchrl_logger.info(f"calling update on worker {i}")
- futures.append(
- rpc.rpc_async(
- self.collector_infos[i],
- self.collector_class.update_policy_weights_,
- args=(self.collector_rrefs[i], self.policy_weights.detach()),
- )
- )
- if wait:
- for i in workers:
if self._VERBOSE:
- torchrl_logger.info(f"waiting for worker {i}")
- futures[i].wait()
- if self._VERBOSE:
- torchrl_logger.info("got it!")
+ torchrl_logger.info(
+ f"Updating policy of workers {ranks} with wait=True"
+ )
+ self.update_policy_weights_(worker_ids=ranks, wait=True)
def _next_async_rpc(self):
if self._VERBOSE:
@@ -698,7 +708,7 @@ def _next_async_rpc(self):
future, i = self.futures.popleft()
if future.done():
if self.update_after_each_batch:
- self.update_policy_weights_(workers=(i,), wait=False)
+ self.update_policy_weights_(worker_ids=(i,), wait=False)
if self._VERBOSE:
torchrl_logger.info(f"future {i} is done")
data = future.value()
@@ -792,3 +802,102 @@ def shutdown(self):
else:
raise NotImplementedError(f"Unknown launcher {self.launcher}")
self._shutdown = True
+
+
+class RPCRemoteWeightUpdater(RemoteWeightUpdaterBase):
+ """A remote weight updater for synchronizing policy weights across remote workers using RPC.
+
+ The `RPCRemoteWeightUpdater` class provides a mechanism for updating the weights of a policy
+ across remote inference workers using RPC. It is designed to work with the :class:`~torchrl.collectors.distributed.RPCDataCollector`
+ to ensure that each worker receives the latest policy weights.
+ This class is typically used in distributed data collection scenarios where remote workers
+ are managed via RPC and need to be kept in sync with the central policy weights.
+
+ Args:
+ collector_infos: Information about the collectors, used for RPC communication.
+ collector_class: The class of the collectors being used.
+ collector_rrefs: Remote references to the collectors.
+ policy_weights (TensorDictBase): The current weights of the policy that need to be distributed
+ to the workers.
+ num_workers (int): The number of remote workers that will receive the updated policy weights.
+
+ Methods:
+ update_weights: Updates the weights on specified or all remote workers using RPC.
+ all_worker_ids: Returns a list of all worker identifiers (not implemented in this class).
+ _sync_weights_with_worker: Synchronizes the server weights with a specific worker (not implemented).
+ _get_server_weights: Retrieves the latest weights from the server (not implemented).
+ _maybe_map_weights: Optionally maps server weights before distribution (not implemented).
+
+ .. note::
+ This class assumes that the server weights can be directly applied to the remote workers
+ without any additional processing. If your use case requires more complex weight mapping or
+ synchronization logic, consider extending `RemoteWeightUpdaterBase` with a custom implementation.
+
+ .. seealso:: :class:`~torchrl.collectors.RemoteWeightUpdaterBase` and
+ :class:`~torchrl.collectors.distributed.RPCDataCollector`.
+
+ """
+
+ _VERBOSE = VERBOSE # for debugging
+
+ def __init__(
+ self,
+ collector_infos,
+ collector_class,
+ collector_rrefs,
+ policy_weights: TensorDictBase,
+ num_workers: int,
+ ):
+ super().__init__()
+ self.collector_infos = collector_infos
+ self.collector_class = collector_class
+ self.collector_rrefs = collector_rrefs
+ self.policy_weights = policy_weights
+ self.num_workers = num_workers
+
+ def _sync_weights_with_worker(
+ self, worker_id: int | torch.device, server_weights: TensorDictBase
+ ) -> TensorDictBase:
+ raise NotImplementedError
+
+ def _get_server_weights(self) -> TensorDictBase:
+ raise NotImplementedError
+
+ def _maybe_map_weights(self, server_weights: TensorDictBase) -> TensorDictBase:
+ raise NotImplementedError
+
+ def all_worker_ids(self) -> list[int] | list[torch.device]:
+ raise NotImplementedError
+
+ def update_weights(
+ self,
+ weights: TensorDictBase | None = None,
+ worker_ids: torch.device | int | list[int] | list[torch.device] | None = None,
+ **kwargs,
+ ):
+ workers = worker_ids
+ if isinstance(workers, int):
+ workers = [workers]
+ if workers is None:
+ workers = list(range(self.num_workers))
+ else:
+ workers = list(workers)
+ futures = []
+ weights = self.policy_weights if weights is None else weights
+ for i in workers:
+ if self._VERBOSE:
+ torchrl_logger.info(f"calling update on worker {i}")
+ futures.append(
+ rpc.rpc_async(
+ self.collector_infos[i],
+ self.collector_class.update_policy_weights_,
+ args=(self.collector_rrefs[i], weights),
+ )
+ )
+ if kwargs.get("wait", True):
+            for i, future in zip(workers, futures):
+                if self._VERBOSE:
+                    torchrl_logger.info(f"waiting for worker {i}")
+                future.wait()
+ if self._VERBOSE:
+ torchrl_logger.info("got it!")
diff --git a/torchrl/collectors/distributed/sync.py b/torchrl/collectors/distributed/sync.py
index b15045f8288..e0cf3dfe953 100644
--- a/torchrl/collectors/distributed/sync.py
+++ b/torchrl/collectors/distributed/sync.py
@@ -601,7 +601,13 @@ def _iterator_dist(self):
data = self.postproc(data)
yield data
- def update_policy_weights_(self, worker_rank=None) -> None:
+ def update_policy_weights_(
+ self,
+ policy_weights: TensorDictBase | None = None,
+ *,
+ worker_ids=None,
+ wait=True,
+ ) -> None:
raise NotImplementedError
def set_seed(self, seed: int, static_seed: bool = False) -> int:
diff --git a/torchrl/collectors/utils.py b/torchrl/collectors/utils.py
index 74bea267c22..536c960c14b 100644
--- a/torchrl/collectors/utils.py
+++ b/torchrl/collectors/utils.py
@@ -13,7 +13,8 @@
_NON_NN_POLICY_WEIGHTS = (
"The policy is not an nn.Module. TorchRL will assume that the parameter set is empty and "
- "update_policy_weights_ will be a no-op."
+ "update_policy_weights_ will be a no-op. Consider passing a local/remote_weight_updater object "
+ "to your collector to handle the weight updates."
)
diff --git a/torchrl/collectors/weight_update.py b/torchrl/collectors/weight_update.py
new file mode 100644
index 00000000000..9911c3228af
--- /dev/null
+++ b/torchrl/collectors/weight_update.py
@@ -0,0 +1,355 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+from __future__ import annotations
+
+import abc
+from abc import abstractmethod
+from typing import Callable, Dict, List, TypeVar
+
+import torch
+from tensordict import TensorDictBase
+from tensordict.nn import TensorDictModuleBase
+
+Policy = TypeVar("Policy", bound=TensorDictModuleBase)
+
+
+class LocalWeightUpdaterBase(metaclass=abc.ABCMeta):
+ """A base class for updating local policy weights from a server.
+
+ This class provides an interface for downloading and updating the weights of a policy
+ on a local inference worker. The update process is decentralized, meaning the inference
+ worker is responsible for fetching the weights from the server.
+
+ To extend this class, implement the following abstract methods:
+
+ - `_get_server_weights`: Define how to retrieve the weights from the server.
+ - `_get_local_weights`: Define how to access the current local weights.
+ - `_maybe_map_weights`: Optionally transform the server weights to match the local weights.
+
+ Attributes:
+ policy (Policy, optional): The policy whose weights are to be updated.
+ get_weights_from_policy (Callable, optional): A function to extract weights from the policy.
+ get_weights_from_server (Callable, optional): A function to fetch weights from the server.
+ weight_map_fn (Callable, optional): A function to map server weights to local weights.
+ cache_policy_weights (bool): Whether to cache the policy weights locally.
+
+ Methods:
+ update_weights: Updates the local weights with the server weights.
+
+
+    .. seealso:: :class:`~torchrl.collectors.RemoteWeightUpdaterBase` and
+ :meth:`~torchrl.collectors.DataCollectorBase.update_policy_weights_`.
+
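+    Example:
+        >>> # Sketch: a custom local updater is passed to the collector at construction time
+        >>> # (``my_updater`` is assumed to implement the abstract methods listed above).
+        >>> collector = SyncDataCollector(
+        ...     env_maker, policy, frames_per_batch=200, total_frames=2000,
+        ...     local_weights_updater=my_updater,
+        ... )
+        >>> collector.update_policy_weights_()  # dispatches to my_updater
+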
+ """
+
+ @abstractmethod
+ def _get_server_weights(self) -> TensorDictBase:
+ ...
+
+ @abstractmethod
+ def _get_local_weights(self) -> TensorDictBase:
+ ...
+
+ @abstractmethod
+ def _maybe_map_weights(
+ self, server_weights: TensorDictBase, local_weights: TensorDictBase
+ ) -> TensorDictBase:
+ ...
+
+ def _update_local_weights(
+ self, local_weights: TensorDictBase, mapped_weights: TensorDictBase
+ ) -> TensorDictBase:
+ local_weights.update_(mapped_weights)
+
+ def __call__(
+ self,
+ weights: TensorDictBase | None = None,
+ ):
+ return self.update_weights(weights=weights)
+
+ def update_weights(self, weights: TensorDictBase | None = None) -> TensorDictBase:
+ if weights is None:
+ # get server weights (source)
+ server_weights = self._get_server_weights()
+ else:
+ server_weights = weights
+ # Get local weights
+ local_weights = self._get_local_weights()
+
+ # Optionally map the weights
+ mapped_weights = self._maybe_map_weights(server_weights, local_weights)
+
+ # Update the weights
+ self._update_local_weights(local_weights, mapped_weights)
+
+
+class RemoteWeightUpdaterBase(metaclass=abc.ABCMeta):
+ """A base class for updating remote policy weights on inference workers.
+
+ This class provides an interface for uploading and synchronizing the weights of a policy
+ across remote inference workers. The update process is centralized, meaning the server
+ is responsible for distributing the weights to the inference nodes.
+
+ To extend this class, implement the following abstract methods:
+
+ - `_sync_weights_with_worker`: Define how to synchronize weights with a specific worker.
+ - `_get_server_weights`: Define how to retrieve the weights from the server.
+ - `_maybe_map_weights`: Optionally transform the server weights before distribution.
+ - `all_worker_ids`: Provide a list of all worker identifiers.
+
+ Attributes:
+ policy (Policy, optional): The policy whose weights are to be updated.
+
+ Methods:
+ update_weights: Updates the weights on specified or all remote workers.
+
+    .. seealso:: :class:`~torchrl.collectors.LocalWeightUpdaterBase` and
+ :meth:`~torchrl.collectors.DataCollectorBase.update_policy_weights_`.
+
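+    Example:
+        >>> # A condensed sketch of a subclass; ``server`` and ``workers`` stand in for
+        >>> # user-provided objects holding the trained weights and the inference copies.
+        >>> class BroadcastRemoteWeightUpdater(RemoteWeightUpdaterBase):
+        ...     def __init__(self, server, workers):
+        ...         self.server = server
+        ...         self.workers = workers
+        ...     def _get_server_weights(self):
+        ...         return self.server.get_weights()
+        ...     def _maybe_map_weights(self, server_weights):
+        ...         return server_weights
+        ...     def all_worker_ids(self):
+        ...         return list(range(len(self.workers)))
+        ...     def _sync_weights_with_worker(self, worker_id, server_weights):
+        ...         self.workers[worker_id].load_weights(server_weights)
+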
+ """
+
+ @abstractmethod
+ def _sync_weights_with_worker(
+ self, worker_id: int | torch.device, server_weights: TensorDictBase
+ ) -> TensorDictBase:
+ ...
+
+ @abstractmethod
+ def _get_server_weights(self) -> TensorDictBase:
+ ...
+
+ @abstractmethod
+ def _maybe_map_weights(self, server_weights: TensorDictBase) -> TensorDictBase:
+ ...
+
+ @abstractmethod
+ def all_worker_ids(self) -> list[int] | List[torch.device]:
+ ...
+
+ def _skip_update(self, worker_id: int | torch.device) -> bool:
+ return False
+
+ def __call__(
+ self,
+ weights: TensorDictBase | None = None,
+ worker_ids: torch.device | int | List[int] | List[torch.device] | None = None,
+ ):
+ return self.update_weights(weights=weights, worker_ids=worker_ids)
+
+ def update_weights(
+ self,
+ weights: TensorDictBase | None = None,
+ worker_ids: torch.device | int | List[int] | List[torch.device] | None = None,
+ ):
+ if weights is None:
+ # Get the weights on server (local)
+ server_weights = self._get_server_weights()
+ else:
+ server_weights = weights
+
+ self._maybe_map_weights(server_weights)
+
+ # Get the remote weights (inference workers)
+ if isinstance(worker_ids, (int, torch.device)):
+ worker_ids = [worker_ids]
+ elif worker_ids is None:
+ worker_ids = self.all_worker_ids()
+ for worker in worker_ids:
+ if self._skip_update(worker):
+ continue
+ self._sync_weights_with_worker(worker, server_weights)
+
+
+# Specialized classes
+class VanillaLocalWeightUpdater(LocalWeightUpdaterBase):
+ """A simple implementation of `LocalWeightUpdaterBase` for updating local policy weights.
+
+ The `VanillaLocalWeightUpdater` class provides a basic mechanism for updating the weights
+ of a local policy by directly fetching them from a specified source. It is typically used
+ in scenarios where the weight update logic is straightforward and does not require any
+ complex mapping or transformation.
+
+ This class is used by default in the `SyncDataCollector` when no custom local weights updater
+ is provided.
+
+ Args:
+ weight_getter (Callable[[], TensorDictBase]): A callable that returns the latest policy
+ weights from the server or another source.
+ policy_weights (TensorDictBase): The current weights of the local policy that need to be updated.
+
+ Methods:
+ _get_server_weights: Retrieves the latest weights from the specified source.
+ _get_local_weights: Accesses the current local policy weights.
+ _map_weights: Directly maps server weights to local weights without transformation.
+ _maybe_map_weights: Optionally maps server weights to local weights (no-op in this implementation).
+ _update_local_weights: Updates the local policy weights with the mapped weights.
+
+ .. note::
+ This class assumes that the server weights can be directly applied to the local policy
+ without any additional processing. If your use case requires more complex weight mapping,
+ consider extending `LocalWeightUpdaterBase` with a custom implementation.
+
+ .. seealso:: :class:`~torchrl.collectors.LocalWeightUpdaterBase` and :class:`~torchrl.collectors.SyncDataCollector`.
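+
+    Example:
+        >>> # Sketch: ``get_trained_weights`` is a placeholder callable returning a TensorDict
+        >>> # of trained parameters, and ``policy_weights`` holds the collector-side parameters.
+        >>> updater = VanillaLocalWeightUpdater(
+        ...     weight_getter=get_trained_weights, policy_weights=policy_weights
+        ... )
+        >>> updater()  # fetches the weights and applies them to policy_weights in-place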
+ """
+
+ def __init__(
+ self,
+ weight_getter: Callable[[], TensorDictBase],
+ policy_weights: TensorDictBase,
+ ):
+ self.weight_getter = weight_getter
+ self.policy_weights = policy_weights
+
+ def _get_server_weights(self) -> TensorDictBase:
+ return self.weight_getter() if self.weight_getter is not None else None
+
+ def _get_local_weights(self) -> TensorDictBase:
+ return self.policy_weights
+
+ def _map_weights(self, server_weights: TensorDictBase) -> TensorDictBase:
+ return server_weights
+
+ def _maybe_map_weights(
+ self, server_weights: TensorDictBase, local_weights: TensorDictBase
+ ) -> TensorDictBase:
+ return server_weights
+
+ def _update_local_weights(
+ self, local_weights: TensorDictBase, mapped_weights: TensorDictBase
+ ) -> TensorDictBase:
+ if local_weights is None or mapped_weights is None:
+ return
+ local_weights.update_(mapped_weights)
+
+
+class MultiProcessedRemoteWeightUpdate(RemoteWeightUpdaterBase):
+ """A remote weight updater for synchronizing policy weights across multiple processes or devices.
+
+ The `MultiProcessedRemoteWeightUpdate` class provides a mechanism for updating the weights
+ of a policy across multiple inference workers in a multiprocessed environment. It is designed
+ to handle the distribution of weights from a central server to various devices or processes
+ that are running the policy.
+ This class is typically used in multiprocessed data collectors where each process or device
+ requires an up-to-date copy of the policy weights.
+
+ Args:
+ get_server_weights (Callable[[], TensorDictBase] | None): A callable that retrieves the
+ latest policy weights from the server or another centralized source.
+ policy_weights (Dict[torch.device, TensorDictBase]): A dictionary mapping each device or
+ process to its current policy weights, which will be updated.
+
+ Methods:
+ all_worker_ids: Returns a list of all worker identifiers (devices or processes).
+ _sync_weights_with_worker: Synchronizes the server weights with a specific worker.
+ _get_server_weights: Retrieves the latest weights from the server.
+ _maybe_map_weights: Optionally maps server weights before distribution (no-op in this implementation).
+
+ .. note::
+ This class assumes that the server weights can be directly applied to the workers without
+ any additional processing. If your use case requires more complex weight mapping or synchronization
+ logic, consider extending `RemoteWeightUpdaterBase` with a custom implementation.
+
+ .. seealso:: :class:`~torchrl.collectors.RemoteWeightUpdaterBase` and
+ :class:`~torchrl.collectors.DataCollectorBase`.
+
+ """
+
+ def __init__(
+ self,
+ get_server_weights: Callable[[], TensorDictBase] | None,
+ policy_weights: Dict[torch.device, TensorDictBase],
+ ):
+ self.weights_getter = get_server_weights
+ self._policy_weights = policy_weights
+
+ def all_worker_ids(self) -> list[int] | List[torch.device]:
+ return list(self._policy_weights)
+
+ def _sync_weights_with_worker(
+ self, worker_id: int | torch.device, server_weights: TensorDictBase
+ ) -> TensorDictBase:
+ if server_weights is None:
+ return
+ self._policy_weights[worker_id].data.update_(server_weights)
+
+ def _get_server_weights(self) -> TensorDictBase:
+ # The weights getter can be none if no mapping is required
+ if self.weights_getter is None:
+ return
+ weights = self.weights_getter()
+ if weights is None:
+ return
+ return weights.data
+
+ def _maybe_map_weights(self, server_weights: TensorDictBase) -> TensorDictBase:
+ return server_weights
+
+
+class RayRemoteWeightUpdater(RemoteWeightUpdaterBase):
+ """A remote weight updater for synchronizing policy weights across remote workers using Ray.
+
+ The `RayRemoteWeightUpdater` class provides a mechanism for updating the weights of a policy
+ across remote inference workers managed by Ray. It leverages Ray's distributed computing
+ capabilities to efficiently distribute policy weights to remote collectors.
+ This class is typically used in distributed data collectors where each remote worker requires
+ an up-to-date copy of the policy weights.
+
+ Args:
+ policy_weights (TensorDictBase): The current weights of the policy that need to be distributed
+ to remote workers.
+ remote_collectors (List): A list of remote collectors that will receive the updated policy weights.
+ max_interval (int, optional): The maximum number of batches between weight updates for each worker.
+ Defaults to 0, meaning weights are updated every batch.
+
+ Methods:
+ all_worker_ids: Returns a list of all worker identifiers (indices of remote collectors).
+ _get_server_weights: Retrieves the latest weights from the server and stores them in Ray's object store.
+ _maybe_map_weights: Optionally maps server weights before distribution (no-op in this implementation).
+ _sync_weights_with_worker: Synchronizes the server weights with a specific remote worker using Ray.
+ _skip_update: Determines whether to skip the weight update for a specific worker based on the interval.
+
+ .. note::
+ This class assumes that the server weights can be directly applied to the remote workers without
+ any additional processing. If your use case requires more complex weight mapping or synchronization
+ logic, consider extending `RemoteWeightUpdaterBase` with a custom implementation.
+
+ .. seealso:: :class:`~torchrl.collectors.RemoteWeightUpdaterBase` and
+ :class:`~torchrl.collectors.distributed.RayCollector`.
+
+ """
+
+ def __init__(
+ self,
+ policy_weights: TensorDictBase,
+ remote_collectors: List,
+ max_interval: int = 0,
+ ):
+ self.policy_weights = policy_weights
+ self.remote_collectors = remote_collectors
+ self.max_interval = max(0, max_interval)
+ self._batches_since_weight_update = [0] * len(self.remote_collectors)
+
+ def all_worker_ids(self) -> list[int] | List[torch.device]:
+ return list(range(len(self.remote_collectors)))
+
+ def _get_server_weights(self) -> TensorDictBase:
+ import ray
+
+ return ray.put(self.policy_weights.data)
+
+ def _maybe_map_weights(self, server_weights: TensorDictBase) -> TensorDictBase:
+ return server_weights
+
+ def _sync_weights_with_worker(
+ self, worker_id: int, server_weights: TensorDictBase
+ ) -> TensorDictBase:
+ c = self.remote_collectors[worker_id]
+ c.update_policy_weights_.remote(policy_weights=server_weights)
+ self._batches_since_weight_update[worker_id] = 0
+
+ def _skip_update(self, worker_id: int) -> bool:
+ self._batches_since_weight_update[worker_id] += 1
+ # Use gt because we just incremented it
+ if self._batches_since_weight_update[worker_id] > self.max_interval:
+ return False
+ return True