test/compile/test_collectors.py: 121 additions, 0 deletions (new file)
@@ -0,0 +1,121 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""Tests for torch.compile compatibility of collectors."""
from __future__ import annotations

import functools
import sys

import pytest
import torch
from packaging import version
from tensordict.nn import TensorDictModule
from torch import nn

from torchrl.collectors import Collector, MultiAsyncCollector, MultiSyncCollector
from torchrl.testing.mocking_classes import ContinuousActionVecMockEnv

TORCH_VERSION = version.parse(version.parse(torch.__version__).base_version)
IS_WINDOWS = sys.platform == "win32"

pytestmark = [
    pytest.mark.filterwarnings(
        "ignore:`torch.jit.script_method` is deprecated:DeprecationWarning"
    ),
]


@pytest.mark.skipif(
    TORCH_VERSION < version.parse("2.5.0"), reason="requires Torch >= 2.5.0"
)
@pytest.mark.skipif(IS_WINDOWS, reason="windows is not supported for compile tests.")
@pytest.mark.skipif(
    sys.version_info >= (3, 14), reason="torch.compile is not supported on Python 3.14+"
)
class TestCompile:
    @pytest.mark.parametrize(
        "collector_cls",
        # Clearing compiled policies causes segfault on machines with cuda
        [Collector, MultiAsyncCollector, MultiSyncCollector]
        if not torch.cuda.is_available()
        else [Collector],
    )
    @pytest.mark.parametrize("compile_policy", [True, {}, {"mode": "default"}])
    @pytest.mark.parametrize(
        "device", [torch.device("cuda:0" if torch.cuda.is_available() else "cpu")]
    )
    def test_compiled_policy(self, collector_cls, compile_policy, device):
        policy = TensorDictModule(
            nn.Linear(7, 7, device=device), in_keys=["observation"], out_keys=["action"]
        )
        make_env = functools.partial(ContinuousActionVecMockEnv, device=device)
        if collector_cls is Collector:
            torch._dynamo.reset_code_caches()
            collector = Collector(
                make_env(),
                policy,
                frames_per_batch=10,
                total_frames=30,
                compile_policy=compile_policy,
            )
            assert collector.compiled_policy
        else:
            collector = collector_cls(
                [make_env] * 2,
                policy,
                frames_per_batch=10,
                total_frames=30,
                compile_policy=compile_policy,
            )
            assert collector.compiled_policy
        try:
            for data in collector:
                assert data is not None
        finally:
            collector.shutdown()
        del collector

    @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is not available")
    @pytest.mark.parametrize(
        "collector_cls",
        [Collector],
    )
    @pytest.mark.parametrize("cudagraph_policy", [True, {}, {"warmup": 10}])
    def test_cudagraph_policy(self, collector_cls, cudagraph_policy):
        device = torch.device("cuda:0")
        policy = TensorDictModule(
            nn.Linear(7, 7, device=device), in_keys=["observation"], out_keys=["action"]
        )
        make_env = functools.partial(ContinuousActionVecMockEnv, device=device)
        if collector_cls is Collector:
            collector = Collector(
                make_env(),
                policy,
                frames_per_batch=30,
                total_frames=120,
                cudagraph_policy=cudagraph_policy,
                device=device,
            )
            assert collector.cudagraphed_policy
        else:
            collector = collector_cls(
                [make_env] * 2,
                policy,
                frames_per_batch=30,
                total_frames=120,
                cudagraph_policy=cudagraph_policy,
                device=device,
            )
            assert collector.cudagraphed_policy
        try:
            for data in collector:
                assert data is not None
        finally:
            collector.shutdown()
        del collector


if __name__ == "__main__":
    pytest.main([__file__, "-v"])
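
For reference, the pattern these tests exercise looks like this in isolation. This is a minimal sketch, reusing the mock env and the 7-dimensional linear policy from the test above; the specific frame counts and compile kwargs are carried over from the test's parametrization, not a canonical recipe:

    # Minimal sketch (assumptions mirror the test above): a single-env Collector
    # whose policy is wrapped in torch.compile. compile_policy may be True or a
    # dict of torch.compile kwargs such as {"mode": "default"}.
    from tensordict.nn import TensorDictModule
    from torch import nn

    from torchrl.collectors import Collector
    from torchrl.testing.mocking_classes import ContinuousActionVecMockEnv

    policy = TensorDictModule(
        nn.Linear(7, 7), in_keys=["observation"], out_keys=["action"]
    )
    collector = Collector(
        ContinuousActionVecMockEnv(),
        policy,
        frames_per_batch=10,
        total_frames=30,
        compile_policy={"mode": "default"},
    )
    assert collector.compiled_policy
    for batch in collector:
        pass  # each batch is a TensorDict holding 10 collected frames
    collector.shutdown()

The multi-process variants follow the same shape, taking a list of env constructors instead of an env instance, as in the `else` branch of the test.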
test/compile/test_objectives.py: 104 additions, 0 deletions (new file)
@@ -0,0 +1,104 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""Tests for torch.compile compatibility of objectives-related modules."""
from __future__ import annotations

import sys

import pytest
import torch

from packaging import version
from tensordict import TensorDict
from tensordict.nn import ProbabilisticTensorDictModule, set_composite_lp_aggregate

from torchrl.envs.utils import exploration_type, ExplorationType, set_exploration_type

TORCH_VERSION = version.parse(version.parse(torch.__version__).base_version)
IS_WINDOWS = sys.platform == "win32"

pytestmark = [
    pytest.mark.filterwarnings(
        "ignore:`torch.jit.script_method` is deprecated:DeprecationWarning"
    ),
]


@pytest.mark.skipif(
    TORCH_VERSION < version.parse("2.5.0"), reason="requires torch>=2.5"
)
@pytest.mark.skipif(IS_WINDOWS, reason="windows tests do not support compile")
@pytest.mark.skipif(
    sys.version_info >= (3, 14), reason="torch.compile is not supported on Python 3.14+"
)
@set_composite_lp_aggregate(False)
def test_exploration_compile():
    try:
        torch._dynamo.reset_code_caches()
    except Exception:
        # older versions of PT don't have that function
        pass
    m = ProbabilisticTensorDictModule(
        in_keys=["loc", "scale"],
        out_keys=["sample"],
        distribution_class=torch.distributions.Normal,
    )

    # class set_exploration_type_random(set_exploration_type):
    #     __init__ = object.__init__
    #     type = ExplorationType.RANDOM
    it = exploration_type()

    @torch.compile(fullgraph=True)
    def func(t):
        with set_exploration_type(ExplorationType.RANDOM):
            t0 = m(t.clone())
            t1 = m(t.clone())
        return t0, t1

    t = TensorDict(loc=torch.randn(3), scale=torch.rand(3))
    t0, t1 = func(t)
    assert (t0["sample"] != t1["sample"]).any()
    assert it == exploration_type()

    @torch.compile(fullgraph=True)
    def func(t):
        with set_exploration_type(ExplorationType.MEAN):
            t0 = m(t.clone())
            t1 = m(t.clone())
        return t0, t1

    t = TensorDict(loc=torch.randn(3), scale=torch.rand(3))
    t0, t1 = func(t)
    assert (t0["sample"] == t1["sample"]).all()
    assert it == exploration_type()

    @torch.compile(fullgraph=True)
    @set_exploration_type(ExplorationType.RANDOM)
    def func(t):
        t0 = m(t.clone())
        t1 = m(t.clone())
        return t0, t1

    t = TensorDict(loc=torch.randn(3), scale=torch.rand(3))
    t0, t1 = func(t)
    assert (t0["sample"] != t1["sample"]).any()
    assert it == exploration_type()

    @torch.compile(fullgraph=True)
    @set_exploration_type(ExplorationType.MEAN)
    def func(t):
        t0 = m(t.clone())
        t1 = m(t.clone())
        return t0, t1

    t = TensorDict(loc=torch.randn(3), scale=torch.rand(3))
    t0, t1 = func(t)
    assert (t0["sample"] == t1["sample"]).all()
    assert it == exploration_type()


if __name__ == "__main__":
    pytest.main([__file__, "-v"])
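
The semantics the test pins down, namely that `set_exploration_type` works both as a context manager and as a decorator, hold outside `torch.compile` as well. A short eager-mode sketch, assuming the eager behavior matches what the compiled graphs assert above (values mirror the test):

    import torch
    from tensordict import TensorDict
    from tensordict.nn import ProbabilisticTensorDictModule
    from torchrl.envs.utils import ExplorationType, set_exploration_type

    m = ProbabilisticTensorDictModule(
        in_keys=["loc", "scale"],
        out_keys=["sample"],
        distribution_class=torch.distributions.Normal,
    )
    t = TensorDict(loc=torch.randn(3), scale=torch.rand(3))

    # RANDOM draws a fresh sample per call, so two calls almost surely differ ...
    with set_exploration_type(ExplorationType.RANDOM):
        assert (m(t.clone())["sample"] != m(t.clone())["sample"]).any()

    # ... while MEAN deterministically returns the distribution's mean.
    with set_exploration_type(ExplorationType.MEAN):
        assert (m(t.clone())["sample"] == m(t.clone())["sample"]).all()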
test/compile/test_utils.py: 58 additions, 0 deletions (new file)
@@ -0,0 +1,58 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""Tests for torch.compile compatibility of utility functions."""
from __future__ import annotations

import sys

import pytest
import torch
from packaging import version

from torchrl.testing import capture_log_records

TORCH_VERSION = version.parse(version.parse(torch.__version__).base_version)

pytestmark = [
    pytest.mark.filterwarnings(
        "ignore:`torch.jit.script_method` is deprecated:DeprecationWarning"
    ),
]


# Check that 'capture_log_records' captures records emitted when torch
# recompiles a function.
@pytest.mark.skipif(
    TORCH_VERSION < version.parse("2.5.0"), reason="requires Torch >= 2.5.0"
)
@pytest.mark.skipif(
    sys.version_info >= (3, 14),
    reason="torch.compile is not supported on Python 3.14+",
)
def test_capture_log_records_recompile():
    torch.compiler.reset()

    # This function recompiles each time it is called with a different string
    # input.
    @torch.compile
    def str_to_tensor(s):
        return bytes(s, "utf8")

    str_to_tensor("a")

    try:
        torch._logging.set_logs(recompiles=True)
        records = []
        capture_log_records(records, "torch._dynamo", "recompiles")
        str_to_tensor("b")
    finally:
        torch._logging.set_logs()

    assert len(records) == 1


if __name__ == "__main__":
    pytest.main([__file__, "-v"])
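
As the test shows, `capture_log_records` appends records emitted on the named logger to the supplied list, so it can watch for recompiles triggered by other guard failures too. A sketch using a tensor-rank change instead of a string change, under the assumption that the second call fails the compiled guards and recompiles exactly once (recompile counts can vary across torch versions):

    import torch
    from torchrl.testing import capture_log_records

    torch.compiler.reset()  # start from a clean compile cache, as in the test

    @torch.compile
    def double(x):
        return x * 2

    double(torch.ones(3))  # first call: initial compilation, no recompile record

    records = []
    try:
        torch._logging.set_logs(recompiles=True)
        capture_log_records(records, "torch._dynamo", "recompiles")
        double(torch.ones(3, 3))  # rank change fails the guards -> one recompile
    finally:
        torch._logging.set_logs()  # restore the default logging configuration

    assert len(records) == 1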
test/test_collectors.py: 0 additions, 90 deletions
@@ -3693,96 +3693,6 @@ def test_dynamic_multiasync_collector(self):
        assert data.names[-1] == "time"


@pytest.mark.skipif(
    TORCH_VERSION < version.parse("2.5.0"), reason="requires Torch >= 2.5.0"
)
@pytest.mark.skipif(IS_WINDOWS, reason="windows is not supported for compile tests.")
@pytest.mark.skipif(
    sys.version_info >= (3, 14), reason="torch.compile is not supported on Python 3.14+"
)
class TestCompile:
    @pytest.mark.parametrize(
        "collector_cls",
        # Clearing compiled policies causes segfault on machines with cuda
        [Collector, MultiAsyncCollector, MultiSyncCollector]
        if not torch.cuda.is_available()
        else [Collector],
    )
    @pytest.mark.parametrize("compile_policy", [True, {}, {"mode": "default"}])
    @pytest.mark.parametrize(
        "device", [torch.device("cuda:0" if torch.cuda.is_available() else "cpu")]
    )
    def test_compiled_policy(self, collector_cls, compile_policy, device):
        policy = TensorDictModule(
            nn.Linear(7, 7, device=device), in_keys=["observation"], out_keys=["action"]
        )
        make_env = functools.partial(ContinuousActionVecMockEnv, device=device)
        if collector_cls is Collector:
            torch._dynamo.reset_code_caches()
            collector = Collector(
                make_env(),
                policy,
                frames_per_batch=10,
                total_frames=30,
                compile_policy=compile_policy,
            )
            assert collector.compiled_policy
        else:
            collector = collector_cls(
                [make_env] * 2,
                policy,
                frames_per_batch=10,
                total_frames=30,
                compile_policy=compile_policy,
            )
            assert collector.compiled_policy
        try:
            for data in collector:
                assert data is not None
        finally:
            collector.shutdown()
        del collector

    @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is not available")
    @pytest.mark.parametrize(
        "collector_cls",
        [Collector],
    )
    @pytest.mark.parametrize("cudagraph_policy", [True, {}, {"warmup": 10}])
    def test_cudagraph_policy(self, collector_cls, cudagraph_policy):
        device = torch.device("cuda:0")
        policy = TensorDictModule(
            nn.Linear(7, 7, device=device), in_keys=["observation"], out_keys=["action"]
        )
        make_env = functools.partial(ContinuousActionVecMockEnv, device=device)
        if collector_cls is Collector:
            collector = Collector(
                make_env(),
                policy,
                frames_per_batch=30,
                total_frames=120,
                cudagraph_policy=cudagraph_policy,
                device=device,
            )
            assert collector.cudagraphed_policy
        else:
            collector = collector_cls(
                [make_env] * 2,
                policy,
                frames_per_batch=30,
                total_frames=120,
                cudagraph_policy=cudagraph_policy,
                device=device,
            )
            assert collector.cudagraphed_policy
        try:
            for data in collector:
                assert data is not None
        finally:
            collector.shutdown()
        del collector


@pytest.mark.skipif(not _has_gym, reason="gym required for this test")
class TestCollectorsNonTensor:
    class AddNontTensorData(Transform):