diff --git a/test/compile/test_collectors.py b/test/compile/test_collectors.py
new file mode 100644
index 00000000000..703e9dc7b91
--- /dev/null
+++ b/test/compile/test_collectors.py
@@ -0,0 +1,121 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+"""Tests for torch.compile compatibility of collectors."""
+from __future__ import annotations
+
+import functools
+import sys
+
+import pytest
+import torch
+from packaging import version
+from tensordict.nn import TensorDictModule
+from torch import nn
+
+from torchrl.collectors import Collector, MultiAsyncCollector, MultiSyncCollector
+from torchrl.testing.mocking_classes import ContinuousActionVecMockEnv
+
+TORCH_VERSION = version.parse(version.parse(torch.__version__).base_version)
+IS_WINDOWS = sys.platform == "win32"
+
+pytestmark = [
+    pytest.mark.filterwarnings(
+        "ignore:`torch.jit.script_method` is deprecated:DeprecationWarning"
+    ),
+]
+
+
+@pytest.mark.skipif(
+    TORCH_VERSION < version.parse("2.5.0"), reason="requires Torch >= 2.5.0"
+)
+@pytest.mark.skipif(IS_WINDOWS, reason="windows is not supported for compile tests.")
+@pytest.mark.skipif(
+    sys.version_info >= (3, 14), reason="torch.compile is not supported on Python 3.14+"
+)
+class TestCompile:
+    @pytest.mark.parametrize(
+        "collector_cls",
+        # Clearing compiled policies causes segfault on machines with cuda
+        [Collector, MultiAsyncCollector, MultiSyncCollector]
+        if not torch.cuda.is_available()
+        else [Collector],
+    )
+    @pytest.mark.parametrize("compile_policy", [True, {}, {"mode": "default"}])
+    @pytest.mark.parametrize(
+        "device", [torch.device("cuda:0" if torch.cuda.is_available() else "cpu")]
+    )
+    def test_compiled_policy(self, collector_cls, compile_policy, device):
+        policy = TensorDictModule(
+            nn.Linear(7, 7, device=device), in_keys=["observation"], out_keys=["action"]
+        )
+        make_env = functools.partial(ContinuousActionVecMockEnv, device=device)
+        if collector_cls is Collector:
+            torch._dynamo.reset_code_caches()
+            collector = Collector(
+                make_env(),
+                policy,
+                frames_per_batch=10,
+                total_frames=30,
+                compile_policy=compile_policy,
+            )
+            assert collector.compiled_policy
+        else:
+            collector = collector_cls(
+                [make_env] * 2,
+                policy,
+                frames_per_batch=10,
+                total_frames=30,
+                compile_policy=compile_policy,
+            )
+            assert collector.compiled_policy
+        try:
+            for data in collector:
+                assert data is not None
+        finally:
+            collector.shutdown()
+            del collector
+
+    @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is not available")
+    @pytest.mark.parametrize(
+        "collector_cls",
+        [Collector],
+    )
+    @pytest.mark.parametrize("cudagraph_policy", [True, {}, {"warmup": 10}])
+    def test_cudagraph_policy(self, collector_cls, cudagraph_policy):
+        device = torch.device("cuda:0")
+        policy = TensorDictModule(
+            nn.Linear(7, 7, device=device), in_keys=["observation"], out_keys=["action"]
+        )
+        make_env = functools.partial(ContinuousActionVecMockEnv, device=device)
+        if collector_cls is Collector:
+            collector = Collector(
+                make_env(),
+                policy,
+                frames_per_batch=30,
+                total_frames=120,
+                cudagraph_policy=cudagraph_policy,
+                device=device,
+            )
+            assert collector.cudagraphed_policy
+        else:
+            collector = collector_cls(
+                [make_env] * 2,
+                policy,
+                frames_per_batch=30,
+                total_frames=120,
+                cudagraph_policy=cudagraph_policy,
+                device=device,
+            )
+            assert collector.cudagraphed_policy
+        try:
+            for data in collector:
+                assert data is not None
+        finally:
+            collector.shutdown()
+            del collector
+
+
+if __name__ == "__main__":
+    pytest.main([__file__, "-v"])
diff --git a/test/compile/test_objectives.py b/test/compile/test_objectives.py
new file mode 100644
index 00000000000..45ee3f2bdd7
--- /dev/null
+++ b/test/compile/test_objectives.py
@@ -0,0 +1,104 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+"""Tests for torch.compile compatibility of objectives-related modules."""
+from __future__ import annotations
+
+import sys
+
+import pytest
+import torch
+
+from packaging import version
+from tensordict import TensorDict
+from tensordict.nn import ProbabilisticTensorDictModule, set_composite_lp_aggregate
+
+from torchrl.envs.utils import exploration_type, ExplorationType, set_exploration_type
+
+TORCH_VERSION = version.parse(version.parse(torch.__version__).base_version)
+IS_WINDOWS = sys.platform == "win32"
+
+pytestmark = [
+    pytest.mark.filterwarnings(
+        "ignore:`torch.jit.script_method` is deprecated:DeprecationWarning"
+    ),
+]
+
+
+@pytest.mark.skipif(
+    TORCH_VERSION < version.parse("2.5.0"), reason="requires torch>=2.5"
+)
+@pytest.mark.skipif(IS_WINDOWS, reason="windows tests do not support compile")
+@pytest.mark.skipif(
+    sys.version_info >= (3, 14), reason="torch.compile is not supported on Python 3.14+"
+)
+@set_composite_lp_aggregate(False)
+def test_exploration_compile():
+    try:
+        torch._dynamo.reset_code_caches()
+    except Exception:
+        # older versions of PT don't have that function
+        pass
+    m = ProbabilisticTensorDictModule(
+        in_keys=["loc", "scale"],
+        out_keys=["sample"],
+        distribution_class=torch.distributions.Normal,
+    )
+
+    # class set_exploration_type_random(set_exploration_type):
+    #     __init__ = object.__init__
+    #     type = ExplorationType.RANDOM
+    it = exploration_type()
+
+    @torch.compile(fullgraph=True)
+    def func(t):
+        with set_exploration_type(ExplorationType.RANDOM):
+            t0 = m(t.clone())
+            t1 = m(t.clone())
+        return t0, t1
+
+    t = TensorDict(loc=torch.randn(3), scale=torch.rand(3))
+    t0, t1 = func(t)
+    assert (t0["sample"] != t1["sample"]).any()
+    assert it == exploration_type()
+
+    @torch.compile(fullgraph=True)
+    def func(t):
+        with set_exploration_type(ExplorationType.MEAN):
+            t0 = m(t.clone())
+            t1 = m(t.clone())
+        return t0, t1
+
+    t = TensorDict(loc=torch.randn(3), scale=torch.rand(3))
+    t0, t1 = func(t)
+    assert (t0["sample"] == t1["sample"]).all()
+    assert it == exploration_type()
+
+    @torch.compile(fullgraph=True)
+    @set_exploration_type(ExplorationType.RANDOM)
+    def func(t):
+        t0 = m(t.clone())
+        t1 = m(t.clone())
+        return t0, t1
+
+    t = TensorDict(loc=torch.randn(3), scale=torch.rand(3))
+    t0, t1 = func(t)
+    assert (t0["sample"] != t1["sample"]).any()
+    assert it == exploration_type()
+
+    @torch.compile(fullgraph=True)
+    @set_exploration_type(ExplorationType.MEAN)
+    def func(t):
+        t0 = m(t.clone())
+        t1 = m(t.clone())
+        return t0, t1
+
+    t = TensorDict(loc=torch.randn(3), scale=torch.rand(3))
+    t0, t1 = func(t)
+    assert (t0["sample"] == t1["sample"]).all()
+    assert it == exploration_type()
+
+
+if __name__ == "__main__":
+    pytest.main([__file__, "-v"])
diff --git a/test/compile/test_utils.py b/test/compile/test_utils.py
new file mode 100644
index 00000000000..934134f51a6
--- /dev/null
+++ b/test/compile/test_utils.py
@@ -0,0 +1,58 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+"""Tests for torch.compile compatibility of utility functions."""
+from __future__ import annotations
+
+import sys
+
+import pytest
+import torch
+from packaging import version
+
+from torchrl.testing import capture_log_records
+
+TORCH_VERSION = version.parse(version.parse(torch.__version__).base_version)
+
+pytestmark = [
+    pytest.mark.filterwarnings(
+        "ignore:`torch.jit.script_method` is deprecated:DeprecationWarning"
+    ),
+]
+
+
+# Check that 'capture_log_records' captures records emitted when torch
+# recompiles a function.
+@pytest.mark.skipif(
+    TORCH_VERSION < version.parse("2.5.0"), reason="requires Torch >= 2.5.0"
+)
+@pytest.mark.skipif(
+    sys.version_info >= (3, 14),
+    reason="torch.compile is not supported on Python 3.14+",
+)
+def test_capture_log_records_recompile():
+    torch.compiler.reset()
+
+    # This function recompiles each time it is called with a different string
+    # input.
+    @torch.compile
+    def str_to_tensor(s):
+        return bytes(s, "utf8")
+
+    str_to_tensor("a")
+
+    try:
+        torch._logging.set_logs(recompiles=True)
+        records = []
+        capture_log_records(records, "torch._dynamo", "recompiles")
+        str_to_tensor("b")
+
+    finally:
+        torch._logging.set_logs()
+
+    assert len(records) == 1
+
+
+if __name__ == "__main__":
+    pytest.main([__file__, "-v"])
diff --git a/test/test_collectors.py b/test/test_collectors.py
index 7b3604a6615..90037bb7fb5 100644
--- a/test/test_collectors.py
+++ b/test/test_collectors.py
@@ -3693,96 +3693,6 @@ def test_dynamic_multiasync_collector(self):
         assert data.names[-1] == "time"
 
 
-@pytest.mark.skipif(
-    TORCH_VERSION < version.parse("2.5.0"), reason="requires Torch >= 2.5.0"
-)
-@pytest.mark.skipif(IS_WINDOWS, reason="windows is not supported for compile tests.")
-@pytest.mark.skipif(
-    sys.version_info >= (3, 14), reason="torch.compile is not supported on Python 3.14+"
-)
-class TestCompile:
-    @pytest.mark.parametrize(
-        "collector_cls",
-        # Clearing compiled policies causes segfault on machines with cuda
-        [Collector, MultiAsyncCollector, MultiSyncCollector]
-        if not torch.cuda.is_available()
-        else [Collector],
-    )
-    @pytest.mark.parametrize("compile_policy", [True, {}, {"mode": "default"}])
-    @pytest.mark.parametrize(
-        "device", [torch.device("cuda:0" if torch.cuda.is_available() else "cpu")]
-    )
-    def test_compiled_policy(self, collector_cls, compile_policy, device):
-        policy = TensorDictModule(
-            nn.Linear(7, 7, device=device), in_keys=["observation"], out_keys=["action"]
-        )
-        make_env = functools.partial(ContinuousActionVecMockEnv, device=device)
-        if collector_cls is Collector:
-            torch._dynamo.reset_code_caches()
-            collector = Collector(
-                make_env(),
-                policy,
-                frames_per_batch=10,
-                total_frames=30,
-                compile_policy=compile_policy,
-            )
-            assert collector.compiled_policy
-        else:
-            collector = collector_cls(
-                [make_env] * 2,
-                policy,
-                frames_per_batch=10,
-                total_frames=30,
-                compile_policy=compile_policy,
-            )
-            assert collector.compiled_policy
-        try:
-            for data in collector:
-                assert data is not None
-        finally:
-            collector.shutdown()
-            del collector
-
-    @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is not available")
-    @pytest.mark.parametrize(
-        "collector_cls",
-        [Collector],
-    )
-    @pytest.mark.parametrize("cudagraph_policy", [True, {}, {"warmup": 10}])
-    def test_cudagraph_policy(self, collector_cls, cudagraph_policy):
-        device = torch.device("cuda:0")
-        policy = TensorDictModule(
-            nn.Linear(7, 7, device=device), in_keys=["observation"], out_keys=["action"]
-        )
-        make_env = functools.partial(ContinuousActionVecMockEnv, device=device)
-        if collector_cls is Collector:
-            collector = Collector(
-                make_env(),
-                policy,
-                frames_per_batch=30,
-                total_frames=120,
-                cudagraph_policy=cudagraph_policy,
-                device=device,
-            )
-            assert collector.cudagraphed_policy
-        else:
-            collector = collector_cls(
-                [make_env] * 2,
-                policy,
-                frames_per_batch=30,
-                total_frames=120,
-                cudagraph_policy=cudagraph_policy,
-                device=device,
-            )
-            assert collector.cudagraphed_policy
-        try:
-            for data in collector:
-                assert data is not None
-        finally:
-            collector.shutdown()
-            del collector
-
-
 @pytest.mark.skipif(not _has_gym, reason="gym required for this test")
 class TestCollectorsNonTensor:
     class AddNontTensorData(Transform):
diff --git a/test/test_objectives.py b/test/test_objectives.py
index 70ec5067179..097418ea004 100644
--- a/test/test_objectives.py
+++ b/test/test_objectives.py
@@ -192,9 +192,6 @@
     ),
     pytest.mark.filterwarnings("ignore:unclosed event loop:ResourceWarning"),
     pytest.mark.filterwarnings("ignore:unclosed.*socket:ResourceWarning"),
-    pytest.mark.filterwarnings(
-        "ignore:`torch.jit.script_method` is deprecated:DeprecationWarning"
-    ),
 ]
 
 
@@ -18066,80 +18063,6 @@ def __init__(self):
         assert p.device == dest
 
 
-@pytest.mark.skipif(
-    TORCH_VERSION < version.parse("2.5.0"), reason="requires torch>=2.5"
-)
-@pytest.mark.skipif(IS_WINDOWS, reason="windows tests do not support compile")
-@pytest.mark.skipif(
-    sys.version_info >= (3, 14), reason="torch.compile is not supported on Python 3.14+"
-)
-@set_composite_lp_aggregate(False)
-def test_exploration_compile():
-    try:
-        torch._dynamo.reset_code_caches()
-    except Exception:
-        # older versions of PT don't have that function
-        pass
-    m = ProbabilisticTensorDictModule(
-        in_keys=["loc", "scale"],
-        out_keys=["sample"],
-        distribution_class=torch.distributions.Normal,
-    )
-
-    # class set_exploration_type_random(set_exploration_type):
-    #     __init__ = object.__init__
-    #     type = ExplorationType.RANDOM
-    it = exploration_type()
-
-    @torch.compile(fullgraph=True)
-    def func(t):
-        with set_exploration_type(ExplorationType.RANDOM):
-            t0 = m(t.clone())
-            t1 = m(t.clone())
-        return t0, t1
-
-    t = TensorDict(loc=torch.randn(3), scale=torch.rand(3))
-    t0, t1 = func(t)
-    assert (t0["sample"] != t1["sample"]).any()
-    assert it == exploration_type()
-
-    @torch.compile(fullgraph=True)
-    def func(t):
-        with set_exploration_type(ExplorationType.MEAN):
-            t0 = m(t.clone())
-            t1 = m(t.clone())
-        return t0, t1
-
-    t = TensorDict(loc=torch.randn(3), scale=torch.rand(3))
-    t0, t1 = func(t)
-    assert (t0["sample"] == t1["sample"]).all()
-    assert it == exploration_type()
-
-    @torch.compile(fullgraph=True)
-    @set_exploration_type(ExplorationType.RANDOM)
-    def func(t):
-        t0 = m(t.clone())
-        t1 = m(t.clone())
-        return t0, t1
-
-    t = TensorDict(loc=torch.randn(3), scale=torch.rand(3))
-    t0, t1 = func(t)
-    assert (t0["sample"] != t1["sample"]).any()
-    assert it == exploration_type()
-
-    @torch.compile(fullgraph=True)
-    @set_exploration_type(ExplorationType.MEAN)
-    def func(t):
-        t0 = m(t.clone())
-        t1 = m(t.clone())
-        return t0, t1
-
-    t = TensorDict(loc=torch.randn(3), scale=torch.rand(3))
-    t0, t1 = func(t)
-    assert (t0["sample"] == t1["sample"]).all()
-    assert it == exploration_type()
-
-
 @set_composite_lp_aggregate(False)
 def test_loss_exploration():
     class DummyLoss(LossModule):
diff --git a/test/test_utils.py b/test/test_utils.py
index d815b8ab482..27d6a0f3bdf 100644
--- a/test/test_utils.py
+++ b/test/test_utils.py
@@ -21,11 +21,7 @@
 
 from torchrl.objectives.utils import _pseudo_vmap
 
-from torchrl.testing import (
-    capture_log_records,
-    get_default_devices,
-    gym_helpers as _gym_helpers,
-)
+from torchrl.testing import get_default_devices, gym_helpers as _gym_helpers
 
 TORCH_VERSION = version.parse(version.parse(torch.__version__).base_version)
 
@@ -390,38 +386,6 @@ def test_rng_decorator(device):
     torch.testing.assert_close(s0b, s1b)
 
 
-# Check that 'capture_log_records' captures records emitted when torch
-# recompiles a function.
-@pytest.mark.skipif(
-    TORCH_VERSION < version.parse("2.5.0"), reason="requires Torch >= 2.5.0"
-)
-@pytest.mark.skipif(
-    sys.version_info >= (3, 14),
-    reason="torch.compile is not supported on Python 3.14+",
-)
-def test_capture_log_records_recompile():
-    torch.compiler.reset()
-
-    # This function recompiles each time it is called with a different string
-    # input.
-    @torch.compile
-    def str_to_tensor(s):
-        return bytes(s, "utf8")
-
-    str_to_tensor("a")
-
-    try:
-        torch._logging.set_logs(recompiles=True)
-        records = []
-        capture_log_records(records, "torch._dynamo", "recompiles")
-        str_to_tensor("b")
-
-    finally:
-        torch._logging.set_logs()
-
-    assert len(records) == 1
-
-
 def add_one(x):
     return x + 1