70 changes: 70 additions & 0 deletions docs/source/reference/collectors.rst
@@ -117,6 +117,76 @@ try to limit the cases where a deepcopy will be executed.

Policy copy decision tree in Collectors.

Weight Synchronization in Distributed Environments
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
In distributed and multiprocessed environments, ensuring that all instances of a policy are synchronized with the
latest trained weights is crucial for consistent performance. TorchRL's collector API provides a flexible and
extensible mechanism for updating policy weights across devices and processes, accommodating various deployment
scenarios.

Local and Remote Weight Updaters
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

The weight synchronization process is facilitated by two main components: :class:`~torchrl.collectors.LocalWeightUpdaterBase`
and :class:`~torchrl.collectors.RemoteWeightUpdaterBase`. These base classes provide a structured interface for
implementing custom weight-update logic, allowing users to tailor the synchronization process to their specific
needs; a sketch of a custom local updater follows the list below.

- :class:`~torchrl.collectors.LocalWeightUpdaterBase`: This component is responsible for updating the policy weights
  on the local inference worker. It is particularly useful when training and inference occur on the same machine but
  on different devices. Users can extend this class to define how weights are fetched from a server and applied
  locally. It is also the extension point for collectors whose workers must request weight updates themselves, in
  contrast with setups where the server decides when to update the worker policies.
- :class:`~torchrl.collectors.RemoteWeightUpdaterBase`: This component handles the distribution of policy weights to
  remote inference workers. It is essential in distributed systems where multiple workers must be kept in sync with
  the central policy. Users can extend this class to implement custom logic for synchronizing weights across a
  network of devices or processes.
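
As an illustration, here is a minimal sketch of a custom local updater that pulls weights from a shared in-memory
store. The hook names used below (``_get_server_weights``, ``_get_local_weights``, ``_maybe_map_weights``,
``_update_local_weights``) are an assumption about the abstract interface; consult the base class for the exact
methods to override.

.. code-block:: python

    from torchrl.collectors import LocalWeightUpdaterBase


    class SharedStoreLocalUpdater(LocalWeightUpdaterBase):
        """Pull weights from a hypothetical key-value store and apply them locally."""

        def __init__(self, store, policy_weights):
            self.store = store  # hypothetical store exposing .get(key)
            self.policy_weights = policy_weights  # TensorDict of the local policy's params

        def _get_server_weights(self):
            # Fetch the latest weights published by the trainer.
            return self.store.get("policy_weights")

        def _get_local_weights(self):
            return self.policy_weights

        def _maybe_map_weights(self, server_weights, local_weights):
            # No key renaming or device mapping needed in this sketch.
            return server_weights

        def _update_local_weights(self, local_weights, mapped_weights):
            # In-place update keeps the references held by the inference policy valid.
            local_weights.update_(mapped_weights)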

Extending the Updater Classes
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

To accommodate diverse use cases, the API allows users to extend the updater classes with custom implementations.
This flexibility is particularly beneficial in scenarios involving complex network architectures or specialized hardware
setups. By implementing the abstract methods in these base classes, users can define how weights are retrieved,
transformed, and applied, ensuring seamless integration with their existing infrastructure.
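
For example, a remote updater that broadcasts weights to workers over per-worker pipes might look as follows. The
same caveat applies: the overridden method names are assumptions about the abstract interface, and the pipe
transport is purely hypothetical.

.. code-block:: python

    from torchrl.collectors import RemoteWeightUpdaterBase


    class PipeBroadcastRemoteUpdater(RemoteWeightUpdaterBase):
        """Push trainer weights to each inference worker through a pipe."""

        def __init__(self, pipes, policy_weights):
            self.pipes = pipes  # hypothetical dict mapping worker_id -> Connection
            self.policy_weights = policy_weights  # TensorDict on the trainer side

        def all_worker_ids(self):
            return list(self.pipes)

        def _get_server_weights(self):
            return self.policy_weights

        def _maybe_map_weights(self, server_weights):
            # Could cast dtypes or move tensors to CPU before shipping; a no-op here.
            return server_weights

        def _sync_weights_with_worker(self, worker_id, server_weights):
            # Ship a copy so each worker can load it onto its own device.
            self.pipes[worker_id].send(server_weights)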

Default Implementations
~~~~~~~~~~~~~~~~~~~~~~~

For common scenarios, the API provides default implementations of these updaters, such as
:class:`~torchrl.collectors.VanillaLocalWeightUpdater`, :class:`~torchrl.collectors.MultiProcessedRemoteWeightUpdate`,
:class:`~torchrl.collectors.RayRemoteWeightUpdater`, :class:`~torchrl.collectors.RPCRemoteWeightUpdater`, and
:class:`~torchrl.collectors.DistributedRemoteWeightUpdater`.
These implementations cover a range of typical deployment configurations, from single-device setups to large-scale
distributed systems.
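
In the common case these defaults are constructed for you: a collector picks the updater that matches its backend,
and calling ``collector.update_policy_weights_()`` triggers it. A minimal single-process sketch (which default the
collector falls back to is an assumption based on the class names above; the rest uses the standard collector API):

.. code-block:: python

    from torch import nn
    from tensordict.nn import TensorDictModule
    from torchrl.collectors import SyncDataCollector
    from torchrl.envs import GymEnv

    env = GymEnv("Pendulum-v1")
    # A tiny deterministic policy mapping observations to actions.
    policy = TensorDictModule(nn.Linear(3, 1), in_keys=["observation"], out_keys=["action"])
    collector = SyncDataCollector(env, policy, frames_per_batch=64, total_frames=640)
    for batch in collector:
        ...  # compute the loss and update the training copy of the policy here
        # Push the freshest weights to the inference policy; with no custom
        # updater configured, the collector falls back to its default
        # (presumably VanillaLocalWeightUpdater for this single-process case).
        collector.update_policy_weights_()
    collector.shutdown()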

Practical Considerations
~~~~~~~~~~~~~~~~~~~~~~~~

When designing a system that leverages this API, consider the following:

- Network Latency: In distributed environments, network latency can impact the speed of weight updates. Ensure that
  your implementation accounts for potential delays and optimizes data transfer where possible.
- Consistency: Ensure that all workers receive the updated weights in a timely manner to maintain consistency across
  the system. This is particularly important in reinforcement learning scenarios, where stale weights can lead to
  suboptimal policy performance.
- Scalability: As your system grows, the weight synchronization mechanism should scale efficiently. Consider the
  overhead of broadcasting weights to a large number of workers and optimize the process to minimize bottlenecks.

By leveraging the API, users can achieve robust and efficient weight synchronization across a variety of deployment
scenarios, ensuring that their policies remain up-to-date and performant.
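
One way to balance consistency against synchronization overhead is to refresh weights on a fixed cadence rather
than at every iteration, as the distributed collectors' ``update_interval`` argument does. Reusing the ``collector``
from the sketch above, a manual version of the same idea (``train_on`` is a hypothetical training step):

.. code-block:: python

    sync_every = 4  # illustrative cadence: refresh worker weights every 4 batches
    for i, batch in enumerate(collector):
        train_on(batch)  # hypothetical training step
        if (i + 1) % sync_every == 0:
            collector.update_policy_weights_()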

.. currentmodule:: torchrl.collectors

.. autosummary::
    :toctree: generated/
    :template: rl_template.rst

    LocalWeightUpdaterBase
    RemoteWeightUpdaterBase
    VanillaLocalWeightUpdater
    MultiProcessedRemoteWeightUpdate
    RayRemoteWeightUpdater
    DistributedRemoteWeightUpdater
    RPCRemoteWeightUpdater

Collectors and replay buffers interoperability
----------------------------------------------
70 changes: 39 additions & 31 deletions test/test_distributed.py
@@ -390,25 +390,29 @@ def _test_distributed_collector_updatepolicy(
            update_interval=update_interval,
            **cls.distributed_kwargs(),
        )
-        total = 0
-        first_batch = None
-        last_batch = None
-        for i, data in enumerate(collector):
-            total += data.numel()
-            assert data.numel() == frames_per_batch
-            if i == 0:
-                first_batch = data
-                policy.weight.data += 1
-            elif total == total_frames - frames_per_batch:
-                last_batch = data
-        assert (first_batch["action"] == 1).all(), first_batch["action"]
-        if update_interval == 1:
-            assert (last_batch["action"] == 2).all(), last_batch["action"]
-        else:
-            assert (last_batch["action"] == 1).all(), last_batch["action"]
-        collector.shutdown()
-        assert total == total_frames
-        queue.put("passed")
+        try:
+            total = 0
+            first_batch = None
+            last_batch = None
+            for i, data in enumerate(collector):
+                total += data.numel()
+                assert data.numel() == frames_per_batch
+                if i == 0:
+                    first_batch = data
+                    policy.weight.data += 1
+                elif total == total_frames - frames_per_batch:
+                    last_batch = data
+            assert (first_batch["action"] == 1).all(), first_batch["action"]
+            if update_interval == 1:
+                assert (last_batch["action"] == 2).all(), last_batch["action"]
+            else:
+                assert (last_batch["action"] == 1).all(), last_batch["action"]
+            assert total == total_frames
+            queue.put("passed")
+        except Exception:
+            # Signal failure to the parent process before re-raising.
+            queue.put("not passed")
+            raise
+        finally:
+            collector.shutdown()

    @pytest.mark.parametrize(
        "collector_class",
@@ -490,12 +494,14 @@ def test_distributed_collector_sync(self, sync, frames_per_batch=200):
            sync=sync,
            **self.distributed_kwargs(),
        )
-        total = 0
-        for data in collector:
-            total += data.numel()
-            assert data.numel() == frames_per_batch
-        collector.shutdown()
-        assert total == 200
+        try:
+            total = 0
+            for data in collector:
+                total += data.numel()
+                assert data.numel() == frames_per_batch
+            assert total == 200
+        finally:
+            collector.shutdown()

    @pytest.mark.parametrize(
        "collector_class",
@@ -517,12 +523,14 @@ def test_distributed_collector_class(self, collector_class):
            frames_per_batch=frames_per_batch,
            **self.distributed_kwargs(),
        )
-        total = 0
-        for data in collector:
-            total += data.numel()
-            assert data.numel() == frames_per_batch
-        collector.shutdown()
-        assert total == 200
+        try:
+            total = 0
+            for data in collector:
+                total += data.numel()
+                assert data.numel() == frames_per_batch
+            assert total == 200
+        finally:
+            collector.shutdown()

    @pytest.mark.parametrize(
        "collector_class",
12 changes: 12 additions & 0 deletions torchrl/collectors/__init__.py
@@ -12,9 +12,21 @@
     MultiSyncDataCollector,
     SyncDataCollector,
 )
+from .weight_update import (
+    LocalWeightUpdaterBase,
+    MultiProcessedRemoteWeightUpdate,
+    RayRemoteWeightUpdater,
+    RemoteWeightUpdaterBase,
+    VanillaLocalWeightUpdater,
+)

 __all__ = [
     "RandomPolicy",
+    "LocalWeightUpdaterBase",
+    "RemoteWeightUpdaterBase",
+    "VanillaLocalWeightUpdater",
+    "RayRemoteWeightUpdater",
+    "MultiProcessedRemoteWeightUpdate",
     "aSyncDataCollector",
     "DataCollectorBase",
     "MultiaSyncDataCollector",