diff --git a/docs/source/reference/collectors.rst b/docs/source/reference/collectors.rst index 50d7e35ed52..61109bca4f8 100644 --- a/docs/source/reference/collectors.rst +++ b/docs/source/reference/collectors.rst @@ -117,6 +117,76 @@ try to limit the cases where a deepcopy will be executed. The following chart sh Policy copy decision tree in Collectors. +Weight Synchronization in Distributed Environments +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +In distributed and multiprocessed environments, ensuring that all instances of a policy are synchronized with the +latest trained weights is crucial for consistent performance. The API introduces a flexible and extensible +mechanism for updating policy weights across different devices and processes, accommodating various deployment scenarios. + +Local and Remote Weight Updaters +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +The weight synchronization process is facilitated by two main components: :class:`~torchrl.collectors.LocalWeightUpdaterBase` +and :class:`~torchrl.collectors.RemoteWeightUpdaterBase`. These base classes provide a structured interface for +implementing custom weight update logic, allowing users to tailor the synchronization process to their specific needs. + +- :class:`~torchrl.collectors.LocalWeightUpdaterBase`: This component is responsible for updating the policy weights on + the local inference worker. It is particularly useful when the training and inference occur on the same machine but on + different devices. Users can extend this class to define how weights are fetched from a server and applied locally. + It is also the extension point for collectors where the workers need to ask for weight updates (in contrast with + situations where the server decides when to update the worker policies). +- :class:`~torchrl.collectors.RemoteWeightUpdaterBase`: This component handles the distribution of policy weights to + remote inference workers. It is essential in distributed systems where multiple workers need to be kept in sync with + the central policy. Users can extend this class to implement custom logic for synchronizing weights across a network of + devices or processes. + +Extending the Updater Classes +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +To accommodate diverse use cases, the API allows users to extend the updater classes with custom implementations. +This flexibility is particularly beneficial in scenarios involving complex network architectures or specialized hardware +setups. By implementing the abstract methods in these base classes, users can define how weights are retrieved, +transformed, and applied, ensuring seamless integration with their existing infrastructure. + +Default Implementations +~~~~~~~~~~~~~~~~~~~~~~~ + +For common scenarios, the API provides default implementations of these updaters, such as +:class:`~torchrl.collectors.VanillaLocalWeightUpdater`, :class:`~torchrl.collectors.MultiProcessedRemoteWeightUpdate`, +:class:`~torchrl.collectors.RayRemoteWeightUpdater`, :class:`~torchrl.collectors.RPCRemoteWeightUpdater`, and +:class:`~torchrl.collectors.DistributedRemoteWeightUpdater`. +These implementations cover a range of typical deployment configurations, from single-device setups to large-scale +distributed systems. + +Practical Considerations +~~~~~~~~~~~~~~~~~~~~~~~~ + +When designing a system that leverages this API, consider the following: + +- Network Latency: In distributed environments, network latency can impact the speed of weight updates. 
Ensure that your + implementation accounts for potential delays and optimizes data transfer where possible. +- Consistency: Ensure that all workers receive the updated weights in a timely manner to maintain consistency across + the system. This is particularly important in reinforcement learning scenarios where stale weights can lead to + suboptimal policy performance. +- Scalability: As your system grows, the weight synchronization mechanism should scale efficiently. Consider the + overhead of broadcasting weights to a large number of workers and optimize the process to minimize bottlenecks. + +By leveraging the API, users can achieve robust and efficient weight synchronization across a variety of deployment +scenarios, ensuring that their policies remain up-to-date and performant. + +.. currentmodule:: torchrl.collectors + +.. autosummary:: + :toctree: generated/ + :template: rl_template.rst + + LocalWeightUpdaterBase + RemoteWeightUpdaterBase + VanillaLocalWeightUpdater + MultiProcessedRemoteWeightUpdate + RayRemoteWeightUpdater + DistributedRemoteWeightUpdater + RPCRemoteWeightUpdater Collectors and replay buffers interoperability ---------------------------------------------- diff --git a/test/test_distributed.py b/test/test_distributed.py index c49596bb26f..41797ae066f 100644 --- a/test/test_distributed.py +++ b/test/test_distributed.py @@ -390,25 +390,29 @@ def _test_distributed_collector_updatepolicy( update_interval=update_interval, **cls.distributed_kwargs(), ) - total = 0 - first_batch = None - last_batch = None - for i, data in enumerate(collector): - total += data.numel() - assert data.numel() == frames_per_batch - if i == 0: - first_batch = data - policy.weight.data += 1 - elif total == total_frames - frames_per_batch: - last_batch = data - assert (first_batch["action"] == 1).all(), first_batch["action"] - if update_interval == 1: - assert (last_batch["action"] == 2).all(), last_batch["action"] - else: - assert (last_batch["action"] == 1).all(), last_batch["action"] - collector.shutdown() - assert total == total_frames - queue.put("passed") + try: + + total = 0 + first_batch = None + last_batch = None + for i, data in enumerate(collector): + total += data.numel() + assert data.numel() == frames_per_batch + if i == 0: + first_batch = data + policy.weight.data += 1 + elif total == total_frames - frames_per_batch: + last_batch = data + assert (first_batch["action"] == 1).all(), first_batch["action"] + if update_interval == 1: + assert (last_batch["action"] == 2).all(), last_batch["action"] + else: + assert (last_batch["action"] == 1).all(), last_batch["action"] + assert total == total_frames + queue.put("passed") + finally: + collector.shutdown() + queue.put("not passed") @pytest.mark.parametrize( "collector_class", @@ -490,12 +494,14 @@ def test_distributed_collector_sync(self, sync, frames_per_batch=200): sync=sync, **self.distributed_kwargs(), ) - total = 0 - for data in collector: - total += data.numel() - assert data.numel() == frames_per_batch - collector.shutdown() - assert total == 200 + try: + total = 0 + for data in collector: + total += data.numel() + assert data.numel() == frames_per_batch + assert total == 200 + finally: + collector.shutdown() @pytest.mark.parametrize( "collector_class", @@ -517,12 +523,14 @@ def test_distributed_collector_class(self, collector_class): frames_per_batch=frames_per_batch, **self.distributed_kwargs(), ) - total = 0 - for data in collector: - total += data.numel() - assert data.numel() == frames_per_batch - collector.shutdown() - assert 
total == 200 + try: + total = 0 + for data in collector: + total += data.numel() + assert data.numel() == frames_per_batch + assert total == 200 + finally: + collector.shutdown() @pytest.mark.parametrize( "collector_class", diff --git a/torchrl/collectors/__init__.py b/torchrl/collectors/__init__.py index 2d40522bb07..b3170e3631a 100644 --- a/torchrl/collectors/__init__.py +++ b/torchrl/collectors/__init__.py @@ -12,9 +12,21 @@ MultiSyncDataCollector, SyncDataCollector, ) +from .weight_update import ( + LocalWeightUpdaterBase, + MultiProcessedRemoteWeightUpdate, + RayRemoteWeightUpdater, + RemoteWeightUpdaterBase, + VanillaLocalWeightUpdater, +) __all__ = [ "RandomPolicy", + "LocalWeightUpdaterBase", + "RemoteWeightUpdaterBase", + "VanillaLocalWeightUpdater", + "RayRemoteWeightUpdater", + "MultiProcessedRemoteWeightUpdate", "aSyncDataCollector", "DataCollectorBase", "MultiaSyncDataCollector", diff --git a/torchrl/collectors/collectors.py b/torchrl/collectors/collectors.py index 58f67229413..3ff35dbb560 100644 --- a/torchrl/collectors/collectors.py +++ b/torchrl/collectors/collectors.py @@ -26,12 +26,7 @@ import numpy as np import torch import torch.nn as nn -from tensordict import ( - LazyStackedTensorDict, - TensorDict, - TensorDictBase, - TensorDictParams, -) +from tensordict import LazyStackedTensorDict, TensorDict, TensorDictBase from tensordict.base import NO_DEFAULT from tensordict.nn import CudaGraphModule, TensorDictModule from tensordict.utils import Buffer @@ -53,6 +48,12 @@ VERBOSE, ) from torchrl.collectors.utils import split_trajectories +from torchrl.collectors.weight_update import ( + LocalWeightUpdaterBase, + MultiProcessedRemoteWeightUpdate, + RemoteWeightUpdaterBase, + VanillaLocalWeightUpdater, +) from torchrl.data import ReplayBuffer from torchrl.data.tensor_specs import TensorSpec from torchrl.data.utils import CloudpickleWrapper, DEVICE_TYPING @@ -151,6 +152,8 @@ class DataCollectorBase(IterableDataset, metaclass=abc.ABCMeta): trust_policy: bool compiled_policy: bool cudagraphed_policy: bool + local_weights_updater: LocalWeightUpdaterBase | None = None + remote_weights_updater: RemoteWeightUpdaterBase | None = None def _get_policy_and_device( self, @@ -242,19 +245,50 @@ def map_weight( return policy, get_original_weights def update_policy_weights_( - self, policy_weights: TensorDictBase | None = None + self, + policy_weights: TensorDictBase | None = None, + *, + worker_ids: int | list[int] | torch.device | list[torch.device] | None = None, + **kwargs, ) -> None: - """Updates the policy weights if the policy of the data collector and the trained policy live on different devices. + """Updates the policy weights for the data collector, accommodating both local and remote execution contexts. + + This method ensures that the policy weights used by the data collector are synchronized with the latest + trained weights. It supports both local and remote weight updates, depending on the configuration of the + data collector. The local (download) update is performed before the remote (upload) update, such that weights + can be transferred to the children workers from a server. Args: - policy_weights (TensorDictBase, optional): if provided, a TensorDict containing - the weights of the policy to be used for the udpdate. + policy_weights (TensorDictBase, optional): A TensorDict containing the weights of the policy to be used + for the update. If not provided, the method will attempt to fetch the weights using the configured + weight updater. 
+ worker_ids (int | List[int] | torch.device | List[torch.device] | None, optional): Identifiers for the + workers that need to be updated. This is relevant when using a remote weights updater, which must + be specified during the data collector's initialization. If `worker_ids` is provided without a + configured remote weights updater, a TypeError will be raised. + + Raises: + TypeError: If `worker_ids` is provided but no `remote_weights_updater` is configured. + + .. note:: + + - The method first attempts to update weights locally using `local_weights_updater`, if available. + - If a `remote_weights_updater` is configured, it will be used to update the specified remote workers. + - Users can extend the `LocalWeightUpdaterBase` and `RemoteWeightUpdaterBase` classes to customize + the weight update logic for specific use cases. This method should not be overridden. + + .. seealso:: :class:`~torchrl.collectors.LocalWeightUpdaterBase` and + :class:`~torchrl.collectors.RemoteWeightUpdaterBase`. """ - if policy_weights is not None: - self.policy_weights.data.update_(policy_weights) - elif self.get_weights_fn is not None: - self.policy_weights.data.update_(self.get_weights_fn()) + if self.local_weights_updater is not None: + self.local_weights_updater(policy_weights, **kwargs) + if self.remote_weights_updater is not None: + self.remote_weights_updater(policy_weights, worker_ids=worker_ids, **kwargs) + elif worker_ids is not None: + raise TypeError( + "worker_ids was passed but remote_weights_updater was None." + ) def __iter__(self) -> Iterator[TensorDictBase]: try: @@ -459,6 +493,13 @@ class SyncDataCollector(DataCollectorBase): or `ManiSkills `_) cuda synchronization may cause unexpected crashes. Defaults to ``False``. + local_weights_updater (LocalWeightUpdaterBase, optional): An instance of :class:`~torchrl.collectors.LocalWeightUpdaterBase` + or its subclass, responsible for updating the policy weights on the local inference worker. + If not provided, a :class:`~torchrl.collectors.VanillaLocalWeightUpdater` will be used by default, + which directly fetches and applies the weights from the server. + remote_weights_updater (RemoteWeightUpdaterBase, optional): An instance of :class:`~torchrl.collectors.RemoteWeightUpdaterBase` + or its subclass, responsible for updating the policy weights on remote inference workers. + This is typically not used in :class:`~torchrl.collectors.SyncDataCollector` as it operates in a single-process environment.
Examples: >>> from torchrl.envs.libs.gym import GymEnv @@ -549,6 +590,8 @@ def __init__( compile_policy: bool | dict[str, Any] | None = None, cudagraph_policy: bool | dict[str, Any] | None = None, no_cuda_sync: bool = False, + local_weights_updater: LocalWeightUpdaterBase | None = None, + remote_weights_updater: RemoteWeightUpdaterBase | None = None, **kwargs, ): from torchrl.envs.batched_envs import BatchedEnvBase @@ -788,6 +831,14 @@ def __init__( self._frames = 0 self._iter = -1 + if local_weights_updater is None: + local_weights_updater = VanillaLocalWeightUpdater( + weight_getter=self.get_weights_fn, policy_weights=self.policy_weights + ) + + self.local_weights_updater = local_weights_updater + self.remote_weights_updater = remote_weights_updater + @property def _traj_pool(self): pool = getattr(self, "_traj_pool_val", None) @@ -998,9 +1049,14 @@ def next(self): # for RPC def update_policy_weights_( - self, policy_weights: TensorDictBase | None = None + self, + policy_weights: TensorDictBase | None = None, + *, + worker_ids: int | list[int] | torch.device | list[torch.device] | None = None, ) -> None: - super().update_policy_weights_(policy_weights) + super().update_policy_weights_( + policy_weights=policy_weights, worker_ids=worker_ids + ) def set_seed(self, seed: int, static_seed: bool = False) -> int: """Sets the seeds of the environments stored in the DataCollector. @@ -1635,6 +1691,13 @@ class _MultiDataCollector(DataCollectorBase): or `ManiSkills `_) cuda synchronization may cause unexpected crashes. Defaults to ``False``. + local_weights_updater (LocalWeightUpdaterBase, optional): An instance of :class:`~torchrl.collectors.LocalWeightUpdaterBase` + or its subclass, responsible for updating the policy weights on each local inference worker. + If not provided, left unused. + remote_weights_updater (RemoteWeightUpdaterBase, optional): An instance of :class:`~torchrl.collectors.RemoteWeightUpdaterBase` + or its subclass, responsible for updating the policy weights on remote inference workers. + If not provided, a :class:`~torchrl.collectors.MultiProcessedRemoteWeightUpdate` will be used by default, + which handles weight synchronization across multiple processes. """ @@ -1672,6 +1735,8 @@ def __init__( compile_policy: bool | dict[str, Any] | None = None, cudagraph_policy: bool | dict[str, Any] | None = None, no_cuda_sync: bool = False, + remote_weights_updater: RemoteWeightUpdaterBase | None = None, + local_weights_updater: LocalWeightUpdaterBase | None = None, ): self.closed = True self.num_workers = len(create_env_fn) @@ -1725,7 +1790,6 @@ def __init__( replay_buffer.share() self._policy_weights_dict = {} - self._get_weights_fn_dict = {} if trust_policy is None: trust_policy = policy is not None and isinstance(policy, CudaGraphModule) @@ -1751,13 +1815,21 @@ def __init__( else TensorDict() ) self._policy_weights_dict[policy_device] = weights - self._get_weights_fn_dict[policy_device] = get_weights_fn - else: + self._get_weights_fn = get_weights_fn + if remote_weights_updater is None: + remote_weights_updater = MultiProcessedRemoteWeightUpdate( + get_server_weights=self._get_weights_fn, + policy_weights=self._policy_weights_dict, + ) + elif remote_weights_updater is None: # TODO raise NotImplementedError( - "weight syncing is not supported for multiprocessed data collectors at the " - "moment." + "remote_weights_updater cannot be None when policy_factory is provided." 
) + + self.remote_weights_updater = remote_weights_updater + self.local_weights_updater = local_weights_updater + self.policy = policy remainder = 0 @@ -1913,21 +1985,6 @@ def _get_devices( def frames_per_batch_worker(self): raise NotImplementedError - def update_policy_weights_(self, policy_weights=None) -> None: - if isinstance(policy_weights, TensorDictParams): - policy_weights = policy_weights.data - for _device in self._policy_weights_dict: - if policy_weights is not None: - self._policy_weights_dict[_device].data.update_(policy_weights) - elif self._get_weights_fn_dict[_device] is not None: - original_weights = self._get_weights_fn_dict[_device]() - if original_weights is None: - # if the weights match in identity, we can spare a call to update_ - continue - if isinstance(original_weights, TensorDictParams): - original_weights = original_weights.data - self._policy_weights_dict[_device].data.update_(original_weights) - @property def _queue_len(self) -> int: raise NotImplementedError @@ -1958,6 +2015,10 @@ def _run_processes(self) -> None: policy_device = self.policy_device[i] storing_device = self.storing_device[i] env_device = self.env_device[i] + # We take the weights, the policy, and locally dispatch the weights to the policy + # while we send the policy to the remote process. + # This makes sure that a given set of shared weights for a given device are + # shared for all policies that rely on that device. policy = self.policy policy_weights = self._policy_weights_dict[policy_device] if policy is not None and policy_weights is not None: @@ -2312,9 +2373,14 @@ def load_state_dict(self, state_dict: OrderedDict) -> None: # for RPC def update_policy_weights_( - self, policy_weights: TensorDictBase | None = None + self, + policy_weights: TensorDictBase | None = None, + *, + worker_ids: int | list[int] | torch.device | list[torch.device] | None = None, ) -> None: - super().update_policy_weights_(policy_weights) + super().update_policy_weights_( + policy_weights=policy_weights, worker_ids=worker_ids + ) @property def frames_per_batch_worker(self): @@ -2676,9 +2742,14 @@ def load_state_dict(self, state_dict: OrderedDict) -> None: # for RPC def update_policy_weights_( - self, policy_weights: TensorDictBase | None = None + self, + policy_weights: TensorDictBase | None = None, + *, + worker_ids: int | list[int] | torch.device | list[torch.device] | None = None, ) -> None: - super().update_policy_weights_(policy_weights) + super().update_policy_weights_( + policy_weights=policy_weights, worker_ids=worker_ids + ) @property def frames_per_batch_worker(self): diff --git a/torchrl/collectors/distributed/__init__.py b/torchrl/collectors/distributed/__init__.py index c28122c6c6a..79a52b1698a 100644 --- a/torchrl/collectors/distributed/__init__.py +++ b/torchrl/collectors/distributed/__init__.py @@ -3,7 +3,11 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. 
-from .generic import DEFAULT_SLURM_CONF, DistributedDataCollector +from .generic import ( + DEFAULT_SLURM_CONF, + DistributedDataCollector, + DistributedRemoteWeightUpdater, +) from .ray import RayCollector from .rpc import RPCDataCollector from .sync import DistributedSyncDataCollector @@ -12,8 +16,9 @@ __all__ = [ "DEFAULT_SLURM_CONF", "DistributedDataCollector", - "RayCollector", - "RPCDataCollector", + "DistributedRemoteWeightUpdater", "DistributedSyncDataCollector", + "RPCDataCollector", + "RayCollector", "submitit_delayed_launcher", ] diff --git a/torchrl/collectors/distributed/generic.py index 00ea799d55f..79fdee553c3 100644 --- a/torchrl/collectors/distributed/generic.py +++ b/torchrl/collectors/distributed/generic.py @@ -31,6 +31,10 @@ TCP_PORT, ) from torchrl.collectors.utils import _NON_NN_POLICY_WEIGHTS, split_trajectories +from torchrl.collectors.weight_update import ( + LocalWeightUpdaterBase, + RemoteWeightUpdaterBase, +) from torchrl.data.utils import CloudpickleWrapper from torchrl.envs.common import EnvBase from torchrl.envs.env_creator import EnvCreator @@ -227,7 +231,7 @@ def _run_collector( # been updated policy_weights.irecv(0) # the policy has been updated: we can simply update the weights - collector.update_policy_weights_(policy_weights) + collector.update_policy_weights_(policy_weights=policy_weights) _store.set(f"NODE_{rank}_out", b"updated") elif instruction.startswith(b"seeding"): seed = int(instruction.split(b"seeding_")) @@ -407,6 +411,15 @@ class DistributedDataCollector(DataCollectorBase): to learn more. Defaults to ``"submitit"``. tcp_port (int, optional): the TCP port to be used. Defaults to 10003. + local_weights_updater (LocalWeightUpdaterBase, optional): An instance of :class:`~torchrl.collectors.LocalWeightUpdaterBase` + or its subclass, responsible for updating the policy weights on the local inference worker. + This is typically not used in :class:`~torchrl.collectors.distributed.DistributedDataCollector` as it + focuses on distributed environments. + remote_weights_updater (RemoteWeightUpdaterBase, optional): An instance of :class:`~torchrl.collectors.RemoteWeightUpdaterBase` + or its subclass, responsible for updating the policy weights on distributed inference workers. + If not provided, a :class:`~torchrl.collectors.distributed.DistributedRemoteWeightUpdater` will be used by + default, which handles weight synchronization across distributed workers. + """ _VERBOSE = VERBOSE # for debugging @@ -439,6 +452,8 @@ def __init__( max_weight_update_interval: int = -1, launcher: str = "submitit", tcp_port: int = None, + remote_weights_updater: RemoteWeightUpdaterBase | None = None, + local_weights_updater: LocalWeightUpdaterBase | None = None, ): if collector_class == "async": @@ -453,6 +468,13 @@ def __init__( if isinstance(policy, nn.Module): policy_weights = TensorDict.from_module(policy) policy_weights = policy_weights.data.lock_() + elif policy_factory is not None: + policy_weights = None + if remote_weights_updater is None: + raise RuntimeError( + "remote_weights_updater must be passed along with " + "a policy_factory."
+ ) else: warnings.warn(_NON_NN_POLICY_WEIGHTS) policy_weights = TensorDict(lock=True) @@ -534,6 +556,15 @@ def __init__( self._init_workers() self._make_container() + if remote_weights_updater is None: + remote_weights_updater = DistributedRemoteWeightUpdater( + store=self._store, + policy_weights=self.policy_weights, + num_workers=self.num_workers, + sync=self._sync, + ) + self.remote_weights_updater = remote_weights_updater + self.local_weights_updater = local_weights_updater @property def device(self) -> list[torch.device]: @@ -817,7 +848,9 @@ def _iterator_dist(self): self._batches_since_weight_update[j] > self.max_weight_update_interval ): - self.update_policy_weights_(rank) + self.update_policy_weights_( + policy_weights=None, worker_ids=rank + ) for i in range(self.num_workers): rank = i + 1 @@ -866,7 +899,7 @@ def _next_async(self, total_frames, trackers): _tracker.wait() data = self._tensordict_out[i].clone() if self.update_after_each_batch: - self.update_policy_weights_(rank) + self.update_policy_weights_(worker_ids=rank) total_frames += data.numel() if total_frames < self.total_frames: if self._VERBOSE: @@ -880,31 +913,6 @@ def _next_async(self, total_frames, trackers): break return data, total_frames - def update_policy_weights_(self, worker_rank=None) -> None: - """Updates the weights of the worker nodes. - - Args: - worker_rank (int, optional): if provided, only this worker weights - will be updated. - """ - if worker_rank is not None and worker_rank < 1: - raise RuntimeError("worker_rank must be greater than 1") - workers = range(self.num_workers) if worker_rank is None else [worker_rank - 1] - for i in workers: - rank = i + 1 - if self._VERBOSE: - torchrl_logger.info(f"updating weights of {rank}") - self._store.set(f"NODE_{rank}_in", b"update_weights") - if self._sync: - self.policy_weights.send(rank) - else: - self.policy_weights.isend(rank) - self._batches_since_weight_update[i] = 0 - status = self._store.get(f"NODE_{rank}_out") - if status != b"updated": - raise RuntimeError(f"Expected 'updated' but got status {status}.") - self._store.delete_key(f"NODE_{rank}_out") - def set_seed(self, seed: int, static_seed: bool = False) -> int: for i in range(self.num_workers): rank = i + 1 @@ -950,3 +958,98 @@ def shutdown(self): pass if self._VERBOSE: torchrl_logger.info("collector shut down") + + +class DistributedRemoteWeightUpdater(RemoteWeightUpdaterBase): + """A remote weight updater for synchronizing policy weights across distributed workers. + + The `DistributedRemoteWeightUpdater` class provides a mechanism for updating the weights + of a policy across distributed inference workers. It is designed to work with the + :class:`~torchrl.collectors.distributed.DistributedDataCollector` to ensure that each worker receives the latest policy weights. + This class is typically used in distributed data collection scenarios where multiple workers + need to be kept in sync with the central policy weights. + + Args: + store (dict[str, str]): A dictionary-like store used for communication between the server + and the distributed workers. + policy_weights (TensorDictBase): The current weights of the policy that need to be distributed + to the workers. + num_workers (int): The number of distributed workers that will receive the updated policy weights. + sync (bool): if ``True``, the sync happens synchronously (the server waits for the worker to have completed + the update to restart the run). + + Methods: + update_weights: Updates the weights on specified or all distributed workers. 
+ all_worker_ids: Returns a list of all worker identifiers (not implemented in this class). + _sync_weights_with_worker: Synchronizes the server weights with a specific worker (not implemented). + _get_server_weights: Retrieves the latest weights from the server (not implemented). + _maybe_map_weights: Optionally maps server weights before distribution (not implemented). + + .. note:: + This class assumes that the server weights can be directly applied to the distributed workers + without any additional processing. If your use case requires more complex weight mapping or + synchronization logic, consider extending `RemoteWeightUpdaterBase` with a custom implementation. + + Raises: + RuntimeError: If the worker rank is less than 1 or if the status returned by the store is not "updated". + + .. seealso:: :class:`~torchrl.collectors.RemoteWeightUpdaterBase` and + :class:`~torchrl.collectors.distributed.DistributedDataCollector`. + + """ + + _VERBOSE = True + + def __init__( + self, + store: dict[str, str], + policy_weights: TensorDictBase, + num_workers: int, + sync: bool, + ): + self._store = store + self.policy_weights = policy_weights + self.num_workers = num_workers + self._sync = sync + self._batches_since_weight_update = [0 for _ in range(self.num_workers)] + + def _sync_weights_with_worker( + self, worker_id: int | torch.device, server_weights: TensorDictBase + ) -> TensorDictBase: + raise NotImplementedError + + def _get_server_weights(self) -> TensorDictBase: + raise NotImplementedError + + def _maybe_map_weights(self, server_weights: TensorDictBase) -> TensorDictBase: + raise NotImplementedError + + def all_worker_ids(self) -> list[int] | list[torch.device]: + raise NotImplementedError + + def update_weights( + self, + weights: TensorDictBase | None = None, + worker_ids: torch.device | int | list[int] | list[torch.device] | None = None, + ): + worker_rank = worker_ids + if isinstance(worker_ids, int): + if worker_rank is not None and worker_rank < 1: + raise RuntimeError("worker_rank must be greater than 1") + worker_rank = [worker_rank - 1] + workers = range(self.num_workers) if worker_rank is None else worker_rank + weights = self.policy_weights if weights is None else weights + for i in workers: + rank = i + 1 + if self._VERBOSE: + torchrl_logger.info(f"updating weights of {rank}") + self._store.set(f"NODE_{rank}_in", b"update_weights") + if self._sync: + weights.send(rank) + else: + weights.isend(rank) + self._batches_since_weight_update[i] = 0 + status = self._store.get(f"NODE_{rank}_out") + if status != b"updated": + raise RuntimeError(f"Expected 'updated' but got status {status}.") + self._store.delete_key(f"NODE_{rank}_out") diff --git a/torchrl/collectors/distributed/ray.py b/torchrl/collectors/distributed/ray.py index 279ef9e4aae..eb43abf7428 100644 --- a/torchrl/collectors/distributed/ray.py +++ b/torchrl/collectors/distributed/ray.py @@ -22,6 +22,11 @@ SyncDataCollector, ) from torchrl.collectors.utils import _NON_NN_POLICY_WEIGHTS, split_trajectories +from torchrl.collectors.weight_update import ( + LocalWeightUpdaterBase, + RayRemoteWeightUpdater, + RemoteWeightUpdaterBase, +) from torchrl.data import ReplayBuffer from torchrl.envs.common import EnvBase from torchrl.envs.env_creator import EnvCreator @@ -255,9 +260,10 @@ class RayCollector(DataCollectorBase): all workers will see their weights updated. For ``sync=False``, only the worker from which the data has been gathered will be updated. + This is equivalent to `max_weight_update_interval=0`. 
Defaults to ``False``, i.e. updates have to be executed manually through - ``torchrl.collectors.distributed.RayDistributedCollector.update_policy_weights_()`` + :meth:`torchrl.collectors.DataCollector.update_policy_weights_` max_weight_update_interval (int, optional): the maximum number of batches that can be collected before the policy weights of a worker is updated. @@ -271,6 +277,13 @@ class RayCollector(DataCollectorBase): .. note:: although it is not enfoced (to allow users to implement their own replay buffer class), a :class:`~torchrl.data.RayReplayBuffer` instance should be used here. + local_weights_updater (LocalWeightUpdaterBase, optional): An instance of :class:`~torchrl.collectors.LocalWeightUpdaterBase` + or its subclass, responsible for updating the policy weights on the local inference worker. + This is typically not used in :class:`~torchrl.collectors.RayCollector` as it focuses on distributed environments. + remote_weights_updater (RemoteWeightUpdaterBase, optional): An instance of :class:`~torchrl.collectors.RemoteWeightUpdaterBase` + or its subclass, responsible for updating the policy weights on remote inference workers managed by Ray. + If not provided, a :class:`~torchrl.collectors.RayRemoteWeightUpdater` will be used by default, leveraging + Ray's distributed capabilities. Examples: >>> from torch import nn @@ -329,6 +342,8 @@ def __init__( update_after_each_batch=False, max_weight_update_interval=-1, replay_buffer: ReplayBuffer = None, + remote_weights_updater: RemoteWeightUpdaterBase | None = None, + local_weights_updater: LocalWeightUpdaterBase | None = None, ): self.frames_per_batch = frames_per_batch if remote_configs is None: @@ -441,16 +456,19 @@ def check_list_length_consistency(*lists): policy_weights = TensorDict.from_module(self._local_policy) policy_weights = policy_weights.data.lock_() else: - warnings.warn(_NON_NN_POLICY_WEIGHTS) policy_weights = TensorDict(lock=True) + if remote_weights_updater is None: + warnings.warn(_NON_NN_POLICY_WEIGHTS) self.policy_weights = policy_weights self.collector_class = collector_class self.collected_frames = 0 self.split_trajs = split_trajs self.total_frames = total_frames self.num_collectors = num_collectors + self.update_after_each_batch = update_after_each_batch self.max_weight_update_interval = max_weight_update_interval + self.collector_kwargs = ( collector_kwargs if collector_kwargs is not None else [{}] ) @@ -507,10 +525,18 @@ def check_list_length_consistency(*lists): collector_kwargs, remote_configs, ) + if remote_weights_updater is None: + remote_weights_updater = RayRemoteWeightUpdater( + policy_weights=policy_weights, + remote_collectors=self.remote_collectors, + max_interval=self.max_weight_update_interval, + ) + self.remote_weights_updater = remote_weights_updater + self.local_weights_updater = local_weights_updater # Print info of all remote workers pending_samples = [ - e.print_remote_collector_info.remote() for e in self.remote_collectors() + e.print_remote_collector_info.remote() for e in self.remote_collectors ] ray.wait(pending_samples) @@ -605,6 +631,7 @@ def local_policy(self): """Returns local collector.""" return self._local_policy + @property def remote_collectors(self): """Returns list of remote collectors.""" return self._remote_collectors @@ -612,7 +639,7 @@ def remote_collectors(self): def stop_remote_collectors(self): """Stops all remote collectors.""" for _ in range(len(self._remote_collectors)): - collector = self.remote_collectors().pop() + collector = self.remote_collectors.pop() # 
collector.__ray_terminate__.remote() # This will kill the actor but let pending tasks finish ray.kill( collector @@ -650,14 +677,11 @@ def proc(data): def _sync_iterator(self) -> Iterator[TensorDictBase]: """Collects one data batch per remote collector in each iteration.""" while self.collected_frames < self.total_frames: - if self.update_after_each_batch: + if self.update_after_each_batch or self.max_weight_update_interval > -1: self.update_policy_weights_() - else: - for j in range(self.num_collectors): - self._batches_since_weight_update[j] += 1 # Ask for batches to all remote workers. - pending_tasks = [e.next.remote() for e in self.remote_collectors()] + pending_tasks = [e.next.remote() for e in self.remote_collectors] # Wait for all rollouts samples_ready = [] @@ -683,15 +707,6 @@ def _sync_iterator(self) -> Iterator[TensorDictBase]: yield out_td - if self.max_weight_update_interval > -1: - for j in range(self.num_collectors): - rank = j + 1 - if ( - self._batches_since_weight_update[j] - > self.max_weight_update_interval - ): - self.update_policy_weights_(rank) - if self._task is None: self.shutdown() @@ -720,19 +735,19 @@ async def async_shutdown(self): def _async_iterator(self) -> Iterator[TensorDictBase]: """Collects a data batch from a single remote collector in each iteration.""" pending_tasks = {} - for index, collector in enumerate(self.remote_collectors()): + for index, collector in enumerate(self.remote_collectors): future = collector.next.remote() pending_tasks[future] = index while self.collected_frames < self.total_frames: - if not len(list(pending_tasks.keys())) == len(self.remote_collectors()): + if not len(list(pending_tasks.keys())) == len(self.remote_collectors): raise RuntimeError("Missing pending tasks, something went wrong") # Wait for first worker to finish wait_results = ray.wait(list(pending_tasks.keys())) future = wait_results[0][0] collector_index = pending_tasks.pop(future) - collector = self.remote_collectors()[collector_index] + collector = self.remote_collectors[collector_index] # Retrieve single rollouts out_td = ray.get(future) @@ -743,18 +758,8 @@ def _async_iterator(self) -> Iterator[TensorDictBase]: yield out_td - for j in range(self.num_collectors): - self._batches_since_weight_update[j] += 1 - if self.update_after_each_batch: - self.update_policy_weights_(worker_rank=collector_index + 1) - elif self.max_weight_update_interval > -1: - for j in range(self.num_collectors): - rank = j + 1 - if ( - self._batches_since_weight_update[j] - > self.max_weight_update_interval - ): - self.update_policy_weights_(rank) + if self.update_after_each_batch or self.max_weight_update_interval > -1: + self.update_policy_weights_(worker_ids=collector_index + 1) # Schedule a new collection task future = collector.next.remote() @@ -773,36 +778,16 @@ def _async_iterator(self) -> Iterator[TensorDictBase]: if self._task is None: self.shutdown() - def update_policy_weights_(self, worker_rank=None) -> None: - """Updates the weights of the worker nodes. - - Args: - worker_rank (int, optional): if provided, only this worker weights - will be updated. 
- """ - # Update agent weights - policy_weights_local_collector_ref = ray.put(self.policy_weights.detach()) - - if worker_rank is None: - for index, e in enumerate(self.remote_collectors()): - e.update_policy_weights_.remote(policy_weights_local_collector_ref) - self._batches_since_weight_update[index] = 0 - else: - self.remote_collectors()[worker_rank - 1].update_policy_weights_.remote( - policy_weights_local_collector_ref - ) - self._batches_since_weight_update[worker_rank - 1] = 0 - def set_seed(self, seed: int, static_seed: bool = False) -> list[int]: """Calls parent method for each remote collector iteratively and returns final seed.""" - for collector in self.remote_collectors(): + for collector in self.remote_collectors: seed = ray.get(object_refs=collector.set_seed.remote(seed, static_seed)) return seed def state_dict(self) -> list[OrderedDict]: """Calls parent method for each remote collector and returns a list of results.""" futures = [ - collector.state_dict.remote() for collector in self.remote_collectors() + collector.state_dict.remote() for collector in self.remote_collectors ] results = ray.get(object_refs=futures) return results @@ -812,8 +797,8 @@ def load_state_dict(self, state_dict: OrderedDict | list[OrderedDict]) -> None: if isinstance(state_dict, OrderedDict): state_dicts = [state_dict] if len(state_dict) == 1: - state_dicts = state_dict * len(self.remote_collectors()) - for collector, state_dict in zip(self.remote_collectors(), state_dicts): + state_dicts = state_dict * len(self.remote_collectors) + for collector, state_dict in zip(self.remote_collectors, state_dicts): collector.load_state_dict.remote(state_dict) def shutdown(self): diff --git a/torchrl/collectors/distributed/rpc.py b/torchrl/collectors/distributed/rpc.py index f205325ed30..7d198375251 100644 --- a/torchrl/collectors/distributed/rpc.py +++ b/torchrl/collectors/distributed/rpc.py @@ -47,6 +47,10 @@ MultiSyncDataCollector, SyncDataCollector, ) +from torchrl.collectors.weight_update import ( + LocalWeightUpdaterBase, + RemoteWeightUpdaterBase, +) from torchrl.envs.common import EnvBase from torchrl.envs.env_creator import EnvCreator @@ -262,6 +266,14 @@ class RPCDataCollector(DataCollectorBase): device used to pass data to main. tensorpipe_options (dict, optional): a dictionary of keyword argument to pass to :class:`torch.distributed.rpc.TensorPipeRpcBackendOption`. + local_weights_updater (LocalWeightUpdaterBase, optional): An instance of :class:`~torchrl.collectors.LocalWeightUpdaterBase` + or its subclass, responsible for updating the policy weights on the local inference worker. This is + typically not used in :class:`~torchrl.collectors.distrbibuted.RPCDataCollector` as it focuses on + distributed environments. + remote_weights_updater (RemoteWeightUpdaterBase, optional): An instance of :class:`~torchrl.collectors.RemoteWeightUpdaterBase` + or its subclass, responsible for updating the policy weights on remote inference workers using RPC. + If not provided, an :class:`~torchrl.collectors.distributed.RPCRemoteWeightUpdater` will be used by default, which + handles weight synchronization via RPC. 
""" @@ -296,6 +308,8 @@ def __init__( tcp_port=None, visible_devices=None, tensorpipe_options=None, + remote_weights_updater: RemoteWeightUpdaterBase | None = None, + local_weights_updater: LocalWeightUpdaterBase | None = None, ): if collector_class == "async": collector_class = MultiaSyncDataCollector @@ -393,6 +407,16 @@ def __init__( tensorpipe_options ) self._init() + if remote_weights_updater is None: + remote_weights_updater = RPCRemoteWeightUpdater( + collector_infos=self.collector_infos, + collector_class=self.collector_class, + collector_rrefs=self.collector_rrefs, + policy_weights=self.policy_weights, + num_workers=self.num_workers, + ) + self.local_weights_updater = local_weights_updater + self.remote_weights_updater = remote_weights_updater @property def device(self) -> list[torch.device]: @@ -655,7 +679,11 @@ def iterator(self): self._batches_since_weight_update[j] > self.max_weight_update_interval ): - self.update_policy_weights_([j], wait=False) + if self._VERBOSE: + torchrl_logger.info( + f"Updating policy of worker {j} with wait=False" + ) + self.update_policy_weights_(worker_ids=[j], wait=False) elif self.max_weight_update_interval > -1: ranks = [ 1 @@ -663,29 +691,11 @@ def iterator(self): if self._batches_since_weight_update[j] > self.max_weight_update_interval ] - self.update_policy_weights_(ranks, wait=True) - - def update_policy_weights_(self, workers=None, wait=True) -> None: - if workers is None: - workers = list(range(self.num_workers)) - futures = [] - for i in workers: - if self._VERBOSE: - torchrl_logger.info(f"calling update on worker {i}") - futures.append( - rpc.rpc_async( - self.collector_infos[i], - self.collector_class.update_policy_weights_, - args=(self.collector_rrefs[i], self.policy_weights.detach()), - ) - ) - if wait: - for i in workers: if self._VERBOSE: - torchrl_logger.info(f"waiting for worker {i}") - futures[i].wait() - if self._VERBOSE: - torchrl_logger.info("got it!") + torchrl_logger.info( + f"Updating policy of workers {ranks} with wait=True" + ) + self.update_policy_weights_(worker_ids=ranks, wait=True) def _next_async_rpc(self): if self._VERBOSE: @@ -698,7 +708,7 @@ def _next_async_rpc(self): future, i = self.futures.popleft() if future.done(): if self.update_after_each_batch: - self.update_policy_weights_(workers=(i,), wait=False) + self.update_policy_weights_(worker_ids=(i,), wait=False) if self._VERBOSE: torchrl_logger.info(f"future {i} is done") data = future.value() @@ -792,3 +802,102 @@ def shutdown(self): else: raise NotImplementedError(f"Unknown launcher {self.launcher}") self._shutdown = True + + +class RPCRemoteWeightUpdater(RemoteWeightUpdaterBase): + """A remote weight updater for synchronizing policy weights across remote workers using RPC. + + The `RPCRemoteWeightUpdater` class provides a mechanism for updating the weights of a policy + across remote inference workers using RPC. It is designed to work with the :class:`~torchrl.collectors.distributed.RPCDataCollector` + to ensure that each worker receives the latest policy weights. + This class is typically used in distributed data collection scenarios where remote workers + are managed via RPC and need to be kept in sync with the central policy weights. + + Args: + collector_infos: Information about the collectors, used for RPC communication. + collector_class: The class of the collectors being used. + collector_rrefs: Remote references to the collectors. + policy_weights (TensorDictBase): The current weights of the policy that need to be distributed + to the workers. 
+ num_workers (int): The number of remote workers that will receive the updated policy weights. + + Methods: + update_weights: Updates the weights on specified or all remote workers using RPC. + all_worker_ids: Returns a list of all worker identifiers (not implemented in this class). + _sync_weights_with_worker: Synchronizes the server weights with a specific worker (not implemented). + _get_server_weights: Retrieves the latest weights from the server (not implemented). + _maybe_map_weights: Optionally maps server weights before distribution (not implemented). + + .. note:: + This class assumes that the server weights can be directly applied to the remote workers + without any additional processing. If your use case requires more complex weight mapping or + synchronization logic, consider extending `RemoteWeightUpdaterBase` with a custom implementation. + + .. seealso:: :class:`~torchrl.collectors.RemoteWeightUpdaterBase` and + :class:`~torchrl.collectors.distributed.RPCDataCollector`. + + """ + + _VERBOSE = VERBOSE # for debugging + + def __init__( + self, + collector_infos, + collector_class, + collector_rrefs, + policy_weights: TensorDictBase, + num_workers: int, + ): + super().__init__() + self.collector_infos = collector_infos + self.collector_class = collector_class + self.collector_rrefs = collector_rrefs + self.policy_weights = policy_weights + self.num_workers = num_workers + + def _sync_weights_with_worker( + self, worker_id: int | torch.device, server_weights: TensorDictBase + ) -> TensorDictBase: + raise NotImplementedError + + def _get_server_weights(self) -> TensorDictBase: + raise NotImplementedError + + def _maybe_map_weights(self, server_weights: TensorDictBase) -> TensorDictBase: + raise NotImplementedError + + def all_worker_ids(self) -> list[int] | list[torch.device]: + raise NotImplementedError + + def update_weights( + self, + weights: TensorDictBase | None = None, + worker_ids: torch.device | int | list[int] | list[torch.device] | None = None, + **kwargs, + ): + workers = worker_ids + if isinstance(workers, int): + workers = [workers] + if workers is None: + workers = list(range(self.num_workers)) + else: + workers = list(workers) + futures = [] + weights = self.policy_weights if weights is None else weights + for i in workers: + if self._VERBOSE: + torchrl_logger.info(f"calling update on worker {i}") + futures.append( + rpc.rpc_async( + self.collector_infos[i], + self.collector_class.update_policy_weights_, + args=(self.collector_rrefs[i], weights), + ) + ) + if kwargs.get("wait", True): + for i in workers: + if self._VERBOSE: + torchrl_logger.info(f"waiting for worker {i}") + futures[i].wait() + if self._VERBOSE: + torchrl_logger.info("got it!") diff --git a/torchrl/collectors/distributed/sync.py b/torchrl/collectors/distributed/sync.py index b15045f8288..e0cf3dfe953 100644 --- a/torchrl/collectors/distributed/sync.py +++ b/torchrl/collectors/distributed/sync.py @@ -601,7 +601,13 @@ def _iterator_dist(self): data = self.postproc(data) yield data - def update_policy_weights_(self, worker_rank=None) -> None: + def update_policy_weights_( + self, + policy_weights: TensorDictBase | None = None, + *, + worker_ids=None, + wait=True, + ) -> None: raise NotImplementedError def set_seed(self, seed: int, static_seed: bool = False) -> int: diff --git a/torchrl/collectors/utils.py b/torchrl/collectors/utils.py index 74bea267c22..536c960c14b 100644 --- a/torchrl/collectors/utils.py +++ b/torchrl/collectors/utils.py @@ -13,7 +13,8 @@ _NON_NN_POLICY_WEIGHTS = ( "The policy is not 
an nn.Module. TorchRL will assume that the parameter set is empty and " - "update_policy_weights_ will be a no-op." + "update_policy_weights_ will be a no-op. Consider passing a local/remote_weight_updater object " + "to your collector to handle the weight updates." ) diff --git a/torchrl/collectors/weight_update.py b/torchrl/collectors/weight_update.py new file mode 100644 index 00000000000..9911c3228af --- /dev/null +++ b/torchrl/collectors/weight_update.py @@ -0,0 +1,355 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +import abc +from abc import abstractmethod +from typing import Callable, Dict, List, TypeVar + +import torch +from tensordict import TensorDictBase +from tensordict.nn import TensorDictModuleBase + +Policy = TypeVar("Policy", bound=TensorDictModuleBase) + + +class LocalWeightUpdaterBase(metaclass=abc.ABCMeta): + """A base class for updating local policy weights from a server. + + This class provides an interface for downloading and updating the weights of a policy + on a local inference worker. The update process is decentralized, meaning the inference + worker is responsible for fetching the weights from the server. + + To extend this class, implement the following abstract methods: + + - `_get_server_weights`: Define how to retrieve the weights from the server. + - `_get_local_weights`: Define how to access the current local weights. + - `_maybe_map_weights`: Optionally transform the server weights to match the local weights. + + Attributes: + policy (Policy, optional): The policy whose weights are to be updated. + get_weights_from_policy (Callable, optional): A function to extract weights from the policy. + get_weights_from_server (Callable, optional): A function to fetch weights from the server. + weight_map_fn (Callable, optional): A function to map server weights to local weights. + cache_policy_weights (bool): Whether to cache the policy weights locally. + + Methods: + update_weights: Updates the local weights with the server weights. + + + .. seealso:: :class:`~torchrl.collectors.RemoteWeightsUpdaterBase` and + :meth:`~torchrl.collectors.DataCollectorBase.update_policy_weights_`. + + """ + + @abstractmethod + def _get_server_weights(self) -> TensorDictBase: + ... + + @abstractmethod + def _get_local_weights(self) -> TensorDictBase: + ... + + @abstractmethod + def _maybe_map_weights( + self, server_weights: TensorDictBase, local_weights: TensorDictBase + ) -> TensorDictBase: + ... + + def _update_local_weights( + self, local_weights: TensorDictBase, mapped_weights: TensorDictBase + ) -> TensorDictBase: + local_weights.update_(mapped_weights) + + def __call__( + self, + weights: TensorDictBase | None = None, + ): + return self.update_weights(weights=weights) + + def update_weights(self, weights: TensorDictBase | None = None) -> TensorDictBase: + if weights is None: + # get server weights (source) + server_weights = self._get_server_weights() + else: + server_weights = weights + # Get local weights + local_weights = self._get_local_weights() + + # Optionally map the weights + mapped_weights = self._maybe_map_weights(server_weights, local_weights) + + # Update the weights + self._update_local_weights(local_weights, mapped_weights) + + +class RemoteWeightUpdaterBase(metaclass=abc.ABCMeta): + """A base class for updating remote policy weights on inference workers. 
+ + This class provides an interface for uploading and synchronizing the weights of a policy + across remote inference workers. The update process is centralized, meaning the server + is responsible for distributing the weights to the inference nodes. + + To extend this class, implement the following abstract methods: + + - `_sync_weights_with_worker`: Define how to synchronize weights with a specific worker. + - `_get_server_weights`: Define how to retrieve the weights from the server. + - `_maybe_map_weights`: Optionally transform the server weights before distribution. + - `all_worker_ids`: Provide a list of all worker identifiers. + + Attributes: + policy (Policy, optional): The policy whose weights are to be updated. + + Methods: + update_weights: Updates the weights on specified or all remote workers. + + .. seealso:: :class:`~torchrl.collectors.LocalWeightsUpdaterBase` and + :meth:`~torchrl.collectors.DataCollectorBase.update_policy_weights_`. + + """ + + @abstractmethod + def _sync_weights_with_worker( + self, worker_id: int | torch.device, server_weights: TensorDictBase + ) -> TensorDictBase: + ... + + @abstractmethod + def _get_server_weights(self) -> TensorDictBase: + ... + + @abstractmethod + def _maybe_map_weights(self, server_weights: TensorDictBase) -> TensorDictBase: + ... + + @abstractmethod + def all_worker_ids(self) -> list[int] | List[torch.device]: + ... + + def _skip_update(self, worker_id: int | torch.device) -> bool: + return False + + def __call__( + self, + weights: TensorDictBase | None = None, + worker_ids: torch.device | int | List[int] | List[torch.device] | None = None, + ): + return self.update_weights(weights=weights, worker_ids=worker_ids) + + def update_weights( + self, + weights: TensorDictBase | None = None, + worker_ids: torch.device | int | List[int] | List[torch.device] | None = None, + ): + if weights is None: + # Get the weights on server (local) + server_weights = self._get_server_weights() + else: + server_weights = weights + + self._maybe_map_weights(server_weights) + + # Get the remote weights (inference workers) + if isinstance(worker_ids, (int, torch.device)): + worker_ids = [worker_ids] + elif worker_ids is None: + worker_ids = self.all_worker_ids() + for worker in worker_ids: + if self._skip_update(worker): + continue + self._sync_weights_with_worker(worker, server_weights) + + +# Specialized classes +class VanillaLocalWeightUpdater(LocalWeightUpdaterBase): + """A simple implementation of `LocalWeightUpdaterBase` for updating local policy weights. + + The `VanillaLocalWeightUpdater` class provides a basic mechanism for updating the weights + of a local policy by directly fetching them from a specified source. It is typically used + in scenarios where the weight update logic is straightforward and does not require any + complex mapping or transformation. + + This class is used by default in the `SyncDataCollector` when no custom local weights updater + is provided. + + Args: + weight_getter (Callable[[], TensorDictBase]): A callable that returns the latest policy + weights from the server or another source. + policy_weights (TensorDictBase): The current weights of the local policy that need to be updated. + + Methods: + _get_server_weights: Retrieves the latest weights from the specified source. + _get_local_weights: Accesses the current local policy weights. + _map_weights: Directly maps server weights to local weights without transformation. 
+ _maybe_map_weights: Optionally maps server weights to local weights (no-op in this implementation). + _update_local_weights: Updates the local policy weights with the mapped weights. + + .. note:: + This class assumes that the server weights can be directly applied to the local policy + without any additional processing. If your use case requires more complex weight mapping, + consider extending `LocalWeightUpdaterBase` with a custom implementation. + + .. seealso:: :class:`~torchrl.collectors.LocalWeightUpdaterBase` and :class:`~torchrl.collectors.SyncDataCollector`. + """ + + def __init__( + self, + weight_getter: Callable[[], TensorDictBase], + policy_weights: TensorDictBase, + ): + self.weight_getter = weight_getter + self.policy_weights = policy_weights + + def _get_server_weights(self) -> TensorDictBase: + return self.weight_getter() if self.weight_getter is not None else None + + def _get_local_weights(self) -> TensorDictBase: + return self.policy_weights + + def _map_weights(self, server_weights: TensorDictBase) -> TensorDictBase: + return server_weights + + def _maybe_map_weights( + self, server_weights: TensorDictBase, local_weights: TensorDictBase + ) -> TensorDictBase: + return server_weights + + def _update_local_weights( + self, local_weights: TensorDictBase, mapped_weights: TensorDictBase + ) -> TensorDictBase: + if local_weights is None or mapped_weights is None: + return + local_weights.update_(mapped_weights) + + +class MultiProcessedRemoteWeightUpdate(RemoteWeightUpdaterBase): + """A remote weight updater for synchronizing policy weights across multiple processes or devices. + + The `MultiProcessedRemoteWeightUpdate` class provides a mechanism for updating the weights + of a policy across multiple inference workers in a multiprocessed environment. It is designed + to handle the distribution of weights from a central server to various devices or processes + that are running the policy. + This class is typically used in multiprocessed data collectors where each process or device + requires an up-to-date copy of the policy weights. + + Args: + get_server_weights (Callable[[], TensorDictBase] | None): A callable that retrieves the + latest policy weights from the server or another centralized source. + policy_weights (Dict[torch.device, TensorDictBase]): A dictionary mapping each device or + process to its current policy weights, which will be updated. + + Methods: + all_worker_ids: Returns a list of all worker identifiers (devices or processes). + _sync_weights_with_worker: Synchronizes the server weights with a specific worker. + _get_server_weights: Retrieves the latest weights from the server. + _maybe_map_weights: Optionally maps server weights before distribution (no-op in this implementation). + + .. note:: + This class assumes that the server weights can be directly applied to the workers without + any additional processing. If your use case requires more complex weight mapping or synchronization + logic, consider extending `RemoteWeightUpdaterBase` with a custom implementation. + + .. seealso:: :class:`~torchrl.collectors.RemoteWeightUpdaterBase` and + :class:`~torchrl.collectors.DataCollectorBase`. 
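+
+    Example:
+        >>> # Minimal sketch (values assumed for illustration): the collector normally
+        >>> # builds this per-device weight mapping itself.
+        >>> import torch
+        >>> from tensordict import TensorDict
+        >>> from torchrl.collectors import MultiProcessedRemoteWeightUpdate
+        >>> worker_weights = {torch.device("cpu"): TensorDict({"weight": torch.zeros(2, 2)}, [])}
+        >>> updater = MultiProcessedRemoteWeightUpdate(
+        ...     get_server_weights=lambda: TensorDict({"weight": torch.ones(2, 2)}, []),
+        ...     policy_weights=worker_weights,
+        ... )
+        >>> updater()  # pushes the server weights to every registered worker
+        >>> bool((worker_weights[torch.device("cpu")]["weight"] == 1).all())
+        True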
+ + """ + + def __init__( + self, + get_server_weights: Callable[[], TensorDictBase] | None, + policy_weights: Dict[torch.device, TensorDictBase], + ): + self.weights_getter = get_server_weights + self._policy_weights = policy_weights + + def all_worker_ids(self) -> list[int] | List[torch.device]: + return list(self._policy_weights) + + def _sync_weights_with_worker( + self, worker_id: int | torch.device, server_weights: TensorDictBase + ) -> TensorDictBase: + if server_weights is None: + return + self._policy_weights[worker_id].data.update_(server_weights) + + def _get_server_weights(self) -> TensorDictBase: + # The weights getter can be none if no mapping is required + if self.weights_getter is None: + return + weights = self.weights_getter() + if weights is None: + return + return weights.data + + def _maybe_map_weights(self, server_weights: TensorDictBase) -> TensorDictBase: + return server_weights + + +class RayRemoteWeightUpdater(RemoteWeightUpdaterBase): + """A remote weight updater for synchronizing policy weights across remote workers using Ray. + + The `RayRemoteWeightUpdater` class provides a mechanism for updating the weights of a policy + across remote inference workers managed by Ray. It leverages Ray's distributed computing + capabilities to efficiently distribute policy weights to remote collectors. + This class is typically used in distributed data collectors where each remote worker requires + an up-to-date copy of the policy weights. + + Args: + policy_weights (TensorDictBase): The current weights of the policy that need to be distributed + to remote workers. + remote_collectors (List): A list of remote collectors that will receive the updated policy weights. + max_interval (int, optional): The maximum number of batches between weight updates for each worker. + Defaults to 0, meaning weights are updated every batch. + + Methods: + all_worker_ids: Returns a list of all worker identifiers (indices of remote collectors). + _get_server_weights: Retrieves the latest weights from the server and stores them in Ray's object store. + _maybe_map_weights: Optionally maps server weights before distribution (no-op in this implementation). + _sync_weights_with_worker: Synchronizes the server weights with a specific remote worker using Ray. + _skip_update: Determines whether to skip the weight update for a specific worker based on the interval. + + .. note:: + This class assumes that the server weights can be directly applied to the remote workers without + any additional processing. If your use case requires more complex weight mapping or synchronization + logic, consider extending `RemoteWeightUpdaterBase` with a custom implementation. + + .. seealso:: :class:`~torchrl.collectors.RemoteWeightUpdaterBase` and + :class:`~torchrl.collectors.distributed.RayCollector`. 
+ + """ + + def __init__( + self, + policy_weights: TensorDictBase, + remote_collectors: List, + max_interval: int = 0, + ): + self.policy_weights = policy_weights + self.remote_collectors = remote_collectors + self.max_interval = max(0, max_interval) + self._batches_since_weight_update = [0] * len(self.remote_collectors) + + def all_worker_ids(self) -> list[int] | List[torch.device]: + return list(range(len(self.remote_collectors))) + + def _get_server_weights(self) -> TensorDictBase: + import ray + + return ray.put(self.policy_weights.data) + + def _maybe_map_weights(self, server_weights: TensorDictBase) -> TensorDictBase: + return server_weights + + def _sync_weights_with_worker( + self, worker_id: int, server_weights: TensorDictBase + ) -> TensorDictBase: + c = self.remote_collectors[worker_id] + c.update_policy_weights_.remote(policy_weights=server_weights) + self._batches_since_weight_update[worker_id] = 0 + + def _skip_update(self, worker_id: int) -> bool: + self._batches_since_weight_update[worker_id] += 1 + # Use gt because we just incremented it + if self._batches_since_weight_update[worker_id] > self.max_interval: + return False + return True