diff --git a/test/compile/test_value.py b/test/compile/test_value.py
new file mode 100644
index 00000000000..5e0f0af1d25
--- /dev/null
+++ b/test/compile/test_value.py
@@ -0,0 +1,295 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+"""Tests for torch.compile compatibility of value estimation functions."""
+from __future__ import annotations
+
+import pytest
+import torch
+
+from torchrl.objectives.value.functional import (
+    generalized_advantage_estimate,
+    td_lambda_return_estimate,
+    vec_generalized_advantage_estimate,
+    vec_td_lambda_return_estimate,
+)
+
+
+class TestValueFunctionCompile:
+    """Test compilation of value estimation functions."""
+
+    @pytest.fixture
+    def value_data(self):
+        """Create test data for value functions."""
+        batch_size = 32
+        time_steps = 15
+        feature_dim = 1
+
+        return {
+            "gamma": 0.99,
+            "lmbda": 0.95,
+            "state_value": torch.randn(batch_size, time_steps, feature_dim),
+            "next_state_value": torch.randn(batch_size, time_steps, feature_dim),
+            "reward": torch.randn(batch_size, time_steps, feature_dim),
+            "done": torch.zeros(batch_size, time_steps, feature_dim, dtype=torch.bool),
+            "terminated": torch.zeros(
+                batch_size, time_steps, feature_dim, dtype=torch.bool
+            ),
+        }
+
+    def test_td_lambda_return_estimate_compiles_fullgraph(self, value_data):
+        """Test that td_lambda_return_estimate (non-vectorized) compiles with fullgraph=True."""
+        result_eager = td_lambda_return_estimate(
+            gamma=value_data["gamma"],
+            lmbda=value_data["lmbda"],
+            next_state_value=value_data["next_state_value"],
+            reward=value_data["reward"],
+            done=value_data["done"],
+            terminated=value_data["terminated"],
+        )
+
+        compiled_fn = torch.compile(
+            td_lambda_return_estimate,
+            fullgraph=True,
+            backend="inductor",
+        )
+
+        result_compiled = compiled_fn(
+            gamma=value_data["gamma"],
+            lmbda=value_data["lmbda"],
+            next_state_value=value_data["next_state_value"],
+            reward=value_data["reward"],
+            done=value_data["done"],
+            terminated=value_data["terminated"],
+        )
+
+        torch.testing.assert_close(result_eager, result_compiled, rtol=1e-4, atol=1e-4)
+
+    def test_generalized_advantage_estimate_compiles_fullgraph(self, value_data):
+        """Test that generalized_advantage_estimate (non-vectorized) compiles with fullgraph=True."""
+        advantage_eager, value_target_eager = generalized_advantage_estimate(
+            gamma=value_data["gamma"],
+            lmbda=value_data["lmbda"],
+            state_value=value_data["state_value"],
+            next_state_value=value_data["next_state_value"],
+            reward=value_data["reward"],
+            done=value_data["done"],
+            terminated=value_data["terminated"],
+        )
+
+        compiled_fn = torch.compile(
+            generalized_advantage_estimate,
+            fullgraph=True,
+            backend="inductor",
+        )
+
+        advantage_compiled, value_target_compiled = compiled_fn(
+            gamma=value_data["gamma"],
+            lmbda=value_data["lmbda"],
+            state_value=value_data["state_value"],
+            next_state_value=value_data["next_state_value"],
+            reward=value_data["reward"],
+            done=value_data["done"],
+            terminated=value_data["terminated"],
+        )
+
+        torch.testing.assert_close(
+            advantage_eager, advantage_compiled, rtol=1e-4, atol=1e-4
+        )
+        torch.testing.assert_close(
+            value_target_eager, value_target_compiled, rtol=1e-4, atol=1e-4
+        )
+
+    def test_vec_td_lambda_return_estimate_fails_fullgraph(self, value_data):
+        """Test that vec_td_lambda_return_estimate fails with fullgraph=True due to data-dependent shapes."""
+        compiled_fn = torch.compile(
+            vec_td_lambda_return_estimate,
+            fullgraph=True,
+            backend="inductor",
+        )
+
+        # This should fail because of data-dependent shapes in _get_num_per_traj
+        with pytest.raises(Exception):
+            compiled_fn(
+                gamma=value_data["gamma"],
+                lmbda=value_data["lmbda"],
+                next_state_value=value_data["next_state_value"],
+                reward=value_data["reward"],
+                done=value_data["done"],
+                terminated=value_data["terminated"],
+            )
+
+    def test_vec_generalized_advantage_estimate_fails_fullgraph(self, value_data):
+        """Test that vec_generalized_advantage_estimate fails with fullgraph=True due to data-dependent shapes."""
+        compiled_fn = torch.compile(
+            vec_generalized_advantage_estimate,
+            fullgraph=True,
+            backend="inductor",
+        )
+
+        # This should fail because of data-dependent shapes in _get_num_per_traj
+        with pytest.raises(Exception):
+            compiled_fn(
+                gamma=value_data["gamma"],
+                lmbda=value_data["lmbda"],
+                state_value=value_data["state_value"],
+                next_state_value=value_data["next_state_value"],
+                reward=value_data["reward"],
+                done=value_data["done"],
+                terminated=value_data["terminated"],
+            )
+
+    def test_td_lambda_with_tensor_gamma_compiles_fullgraph(self, value_data):
+        """Test that td_lambda_return_estimate compiles with 0-d tensor gamma (fullgraph=True).
+
+        This tests the fix for the PendingUnbackedSymbolNotFound error that occurred
+        when torch.full_like received a 0-d tensor and internally called .item().
+        """
+        # Use 0-d tensor gamma/lmbda - this was the problematic case
+        gamma_tensor = torch.tensor(value_data["gamma"])
+        lmbda_tensor = torch.tensor(value_data["lmbda"])
+
+        result_eager = td_lambda_return_estimate(
+            gamma=gamma_tensor,
+            lmbda=lmbda_tensor,
+            next_state_value=value_data["next_state_value"],
+            reward=value_data["reward"],
+            done=value_data["done"],
+            terminated=value_data["terminated"],
+        )
+
+        compiled_fn = torch.compile(
+            td_lambda_return_estimate,
+            fullgraph=True,
+            backend="inductor",
+        )
+
+        result_compiled = compiled_fn(
+            gamma=gamma_tensor,
+            lmbda=lmbda_tensor,
+            next_state_value=value_data["next_state_value"],
+            reward=value_data["reward"],
+            done=value_data["done"],
+            terminated=value_data["terminated"],
+        )
+
+        torch.testing.assert_close(result_eager, result_compiled, rtol=1e-4, atol=1e-4)
+
+    def test_gae_with_tensor_gamma_compiles_fullgraph(self, value_data):
+        """Test that generalized_advantage_estimate compiles with 0-d tensor gamma (fullgraph=True).
+
+        This tests the fix for the PendingUnbackedSymbolNotFound error that occurred
+        when torch.full_like received a 0-d tensor and internally called .item().
+ """ + # Use 0-d tensor gamma/lmbda - this was the problematic case + gamma_tensor = torch.tensor(value_data["gamma"]) + lmbda_tensor = torch.tensor(value_data["lmbda"]) + + advantage_eager, value_target_eager = generalized_advantage_estimate( + gamma=gamma_tensor, + lmbda=lmbda_tensor, + state_value=value_data["state_value"], + next_state_value=value_data["next_state_value"], + reward=value_data["reward"], + done=value_data["done"], + terminated=value_data["terminated"], + ) + + compiled_fn = torch.compile( + generalized_advantage_estimate, + fullgraph=True, + backend="inductor", + ) + + advantage_compiled, value_target_compiled = compiled_fn( + gamma=gamma_tensor, + lmbda=lmbda_tensor, + state_value=value_data["state_value"], + next_state_value=value_data["next_state_value"], + reward=value_data["reward"], + done=value_data["done"], + terminated=value_data["terminated"], + ) + + torch.testing.assert_close( + advantage_eager, advantage_compiled, rtol=1e-4, atol=1e-4 + ) + torch.testing.assert_close( + value_target_eager, value_target_compiled, rtol=1e-4, atol=1e-4 + ) + + +class TestTDLambdaEstimatorCompile: + """Test TDLambdaEstimator compile-friendly vectorized property.""" + + def test_vectorized_property_returns_true_in_eager_mode(self): + """Test that TDLambdaEstimator.vectorized returns True in eager mode when set to True.""" + from tensordict.nn import TensorDictModule + from torch import nn + + from torchrl.objectives.value.advantages import TDLambdaEstimator + + value_net = TensorDictModule( + nn.Linear(3, 1), in_keys=["obs"], out_keys=["state_value"] + ) + estimator = TDLambdaEstimator( + gamma=0.99, + lmbda=0.95, + value_network=value_net, + vectorized=True, + ) + + assert estimator.vectorized is True + assert estimator._vectorized is True + + def test_vectorized_property_returns_false_in_eager_mode_when_set_false(self): + """Test that TDLambdaEstimator.vectorized returns False in eager mode when set to False.""" + from tensordict.nn import TensorDictModule + from torch import nn + + from torchrl.objectives.value.advantages import TDLambdaEstimator + + value_net = TensorDictModule( + nn.Linear(3, 1), in_keys=["obs"], out_keys=["state_value"] + ) + estimator = TDLambdaEstimator( + gamma=0.99, + lmbda=0.95, + value_network=value_net, + vectorized=False, + ) + + assert estimator.vectorized is False + assert estimator._vectorized is False + + def test_vectorized_setter_works(self): + """Test that TDLambdaEstimator.vectorized setter works correctly.""" + from tensordict.nn import TensorDictModule + from torch import nn + + from torchrl.objectives.value.advantages import TDLambdaEstimator + + value_net = TensorDictModule( + nn.Linear(3, 1), in_keys=["obs"], out_keys=["state_value"] + ) + estimator = TDLambdaEstimator( + gamma=0.99, + lmbda=0.95, + value_network=value_net, + vectorized=True, + ) + + assert estimator.vectorized is True + + estimator.vectorized = False + assert estimator.vectorized is False + assert estimator._vectorized is False + + estimator.vectorized = True + assert estimator.vectorized is True + assert estimator._vectorized is True + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/torchrl/objectives/value/functional.py b/torchrl/objectives/value/functional.py index 1615b8688cd..c10e8a96540 100644 --- a/torchrl/objectives/value/functional.py +++ b/torchrl/objectives/value/functional.py @@ -333,13 +333,15 @@ def vec_generalized_advantage_estimate( gammalmbdas = _make_gammas_tensor(gammalmbdas, time_steps, True) gammalmbdas = 
 
-    first_below_thr = gammalmbdas < 1e-7
-    # if we have multiple gammas, we only want to truncate if _all_ of
-    # the geometric sequences fall below the threshold
-    first_below_thr = first_below_thr.flatten(0, 1).all(0).all(-1)
-    if first_below_thr.any():
-        first_below_thr = torch.where(first_below_thr)[0][0].item()
-    gammalmbdas = gammalmbdas[..., :first_below_thr, :]
+    # Skip the data-dependent truncation optimization during compile (it causes guards)
+    if not is_dynamo_compiling():
+        first_below_thr = gammalmbdas < 1e-7
+        # if we have multiple gammas, we only want to truncate if _all_ of
+        # the geometric sequences fall below the threshold
+        first_below_thr = first_below_thr.flatten(0, 1).all(0).all(-1)
+        if first_below_thr.any():
+            first_below_thr = torch.where(first_below_thr)[0][0].item()
+        gammalmbdas = gammalmbdas[..., :first_below_thr, :]
 
     not_terminated = (~terminated).to(dtype)
     td0 = reward + not_terminated * gamma * next_state_value - state_value
@@ -524,7 +526,14 @@ def td1_return_estimate(
     single_gamma = False
     if not (isinstance(gamma, torch.Tensor) and gamma.shape == not_done.shape):
         single_gamma = True
-        gamma = torch.full_like(next_state_value, gamma)
+        if isinstance(gamma, torch.Tensor):
+            # Use expand instead of full_like to avoid the .item() call, which
+            # creates unbacked symbols during torch.compile tracing.
+            if gamma.device != next_state_value.device:
+                gamma = gamma.to(next_state_value.device)
+            gamma = gamma.expand(next_state_value.shape)
+        else:
+            gamma = torch.full_like(next_state_value, gamma)
 
     if rolling_gamma is None:
         rolling_gamma = True
@@ -847,12 +856,26 @@ def td_lambda_return_estimate(
     single_gamma = False
     if not (isinstance(gamma, torch.Tensor) and gamma.shape == done.shape):
         single_gamma = True
-        gamma = torch.full_like(next_state_value, gamma)
+        if isinstance(gamma, torch.Tensor):
+            # Use expand instead of full_like to avoid the .item() call, which
+            # creates unbacked symbols during torch.compile tracing.
+            if gamma.device != next_state_value.device:
+                gamma = gamma.to(next_state_value.device)
+            gamma = gamma.expand(next_state_value.shape)
+        else:
+            gamma = torch.full_like(next_state_value, gamma)
 
     single_lambda = False
     if not (isinstance(lmbda, torch.Tensor) and lmbda.shape == done.shape):
         single_lambda = True
-        lmbda = torch.full_like(next_state_value, lmbda)
+        if isinstance(lmbda, torch.Tensor):
+            # Use expand instead of full_like to avoid the .item() call, which
+            # creates unbacked symbols during torch.compile tracing.
+            if lmbda.device != next_state_value.device:
+                lmbda = lmbda.to(next_state_value.device)
+            lmbda = lmbda.expand(next_state_value.shape)
+        else:
+            lmbda = torch.full_like(next_state_value, lmbda)
 
     if rolling_gamma is None:
         rolling_gamma = True
@@ -1004,7 +1027,8 @@ def _fast_td_lambda_return_estimate(
     # the only valid next states are those where the trajectory does not terminate
     next_state_value = (~terminated).int() * next_state_value
 
-    gamma_tensor = torch.tensor([gamma], device=device)
+    # Use torch.full to create the tensor directly on the device (avoids a DeviceCopy in cudagraph)
+    gamma_tensor = torch.full((1,), gamma, device=device)
     gammalmbda = gamma_tensor * lmbda
 
     num_per_traj = _get_num_per_traj(done)
@@ -1125,7 +1149,8 @@ def _is_scalar(tensor):
    if rolling_gamma is None:
        rolling_gamma = True
 
-    if not rolling_gamma:
+    if not rolling_gamma and not is_dynamo_compiling():
+        # Skip this validation during compile to avoid CUDA syncs
         terminated_follows_terminated = terminated[..., 1:, :][
             terminated[..., :-1, :]
         ].all()
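
Illustration only, not part of the patch: a minimal, self-contained sketch of the full_like-vs-expand pattern the comments above describe. The helper name _broadcast_scalar is hypothetical; the behavior it demonstrates is the one stated in the patch comments, namely that torch.full_like with a 0-d tensor fill value triggers an internal .item() call during torch.compile tracing, while a broadcast view via expand does not.

import torch

def _broadcast_scalar(scalar, ref):
    # Hypothetical helper mirroring the pattern introduced by the patch.
    if isinstance(scalar, torch.Tensor):
        # expand() returns a broadcast view; no .item() is triggered, so no
        # unbacked symbols appear during torch.compile tracing.
        return scalar.to(ref.device).expand(ref.shape)
    # Plain Python scalars are safe to pass to full_like.
    return torch.full_like(ref, scalar)

ref = torch.randn(32, 15, 1)
compiled = torch.compile(_broadcast_scalar, fullgraph=True)
out = compiled(torch.tensor(0.99), ref)  # 0-d tensor: the previously failing case
assert out.shape == ref.shape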
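
Likewise for illustration, a sketch of the eager-only guard used in the functional.py hunks. The import path assumes the torch.compiler API available since PyTorch 2.1; the file's own is_dynamo_compiling import is outside the hunks shown here. The function returns False in eager mode, so the data-dependent fast path still runs outside of compilation.

import torch
from torch.compiler import is_dynamo_compiling

def maybe_truncate(gammas):
    # Eager-only optimization: cut the time axis once every geometric
    # sequence has decayed below threshold. Under torch.compile the
    # .item() call below would create a data-dependent guard, so the
    # branch is skipped and the full-length tensor is used instead.
    if not is_dynamo_compiling():
        below = (gammas < 1e-7).all(0).all(-1)
        if below.any():
            cut = torch.where(below)[0][0].item()
            gammas = gammas[..., :cut, :]
    return gammas

g = torch.tensor(0.95) ** torch.arange(512).view(1, -1, 1)
print(maybe_truncate(g).shape)  # truncated in eager mode, e.g. roughly (1, 315, 1)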